未验证 提交 100edd84 编写于 作者: L LielinJiang 提交者: GitHub

Refine codes and docs (#61)

* refine codes docs
上级 b2cc92a5
......@@ -32,31 +32,10 @@ PaddleGAN 是一个基于飞桨的生成对抗网络开发工具包.
## 安装
### 1. 安装 paddlepaddle
PaddleGAN 所需的版本:
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
```
pip install -U paddlepaddle-gpu
```
### 2. 安装ppgan
```
python -m pip install 'git+https://github.com/PaddlePaddle/PaddleGAN.git'
```
或者通过将项目克隆到本地
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
请参考[安装文档](./docs/install.md)来进行PaddlePaddle和ppgan的安装
## 数据准备
请参考 [数据准备](./docs/data_prepare.md) 来准备对应的数据.
请参考[数据准备](./docs/data_prepare.md) 来准备对应的数据.
## 快速开始
......
English | [简体中文](./README_cn.md)
English | [简体中文](./README.md)
# PaddleGAN
......@@ -35,35 +35,13 @@ changes.
## Install
### 1. install paddlepaddle
PaddleGAN work with:
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
```
pip install -U paddlepaddle-gpu
```
### 2. install ppgan
```
python -m pip install 'git+https://github.com/PaddlePaddle/PaddleGAN.git'
```
Or install it from a local clone
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
Please refer to [install](./docs/install_en.md).
## Data Prepare
Please refer to [data prepare](./docs/data_prepare.md) for dataset preparation.
Please refer to [data prepare](./docs/data_prepare_en.md) for dataset preparation.
## Get Start
Please refer [get started](./docs/get_started.md) for the basic usage of PaddleGAN.
Please refer [get started](./docs/get_started_en.md) for the basic usage of PaddleGAN.
## Model tutorial
* [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md)
......
import os
import argparse
from ppgan.utils.download import get_path_from_url
CYCLEGAN_URL_ROOT = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/'
parser = argparse.ArgumentParser(description='download datasets')
parser.add_argument('--name',
type=str,
required=True,
help='dataset name, \
support dataset name: apple2orange, summer2winter_yosemite, \
horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, \
vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, \
ae_photos, cityscapes')
if __name__ == "__main__":
args = parser.parse_args()
data_url = CYCLEGAN_URL_ROOT + args.name + '.zip'
if args.name == 'cityscapes':
data_url = 'https://paddlegan.bj.bcebos.com/datasets/cityscapes.zip'
path = get_path_from_url(data_url)
dst = os.path.join('data', args.name)
print('symlink {} to {}'.format(path, dst))
os.symlink(path, dst)
import os
import argparse
from ppgan.utils.download import get_path_from_url
PIX2PIX_URL_ROOT = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/'
parser = argparse.ArgumentParser(description='download datasets')
parser.add_argument('--name',
type=str,
required=True,
help='dataset name, \
support dataset name: cityscapes, night2day, edges2handbags, \
edges2shoes, facades, maps')
if __name__ == "__main__":
args = parser.parse_args()
data_url = PIX2PIX_URL_ROOT + args.name + '.tar.gz'
path = get_path_from_url(data_url)
dst = os.path.join('data', args.name)
print('symlink {} to {}'.format(path, dst))
os.symlink(path, dst)
## 数据准备
推荐把数据集软链接到 `$PaddleGAN/data`. 软链接后的目录结构如下图所示:
现有的配置默认数据集的路径是在`$PaddleGAN/data`下,目录结构如下图所示。如果你已经下载好数据集了,建议将数据集软链接到 `$PaddleGAN/data`
```
PaddleGAN
......@@ -28,8 +28,65 @@ PaddleGAN
```
### cyclegan 相关的数据集下载
如果将数据集放在其他位置,比如 ```your/data/path```
你可以修改配置文件中的 ```dataroot``` 参数:
```
dataset:
train:
name: PairedDataset
dataroot: your/data/path
num_workers: 4
```
### CycleGAN模型相关的数据集下载
#### 已有的数据集下载
##### 从网页下载
cyclgan模型相关的数据集可以在[这里](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)下载
### pix2pix 相关的数据集下载
##### 使用脚本下载
我们在 ```PaddleGAN/data``` 文件夹下提供了一个脚本 ```download_cyclegan_data.py``` 方便下载CycleGAN相关的
数据集。执行如下命令可以下载相关的数据集,目前支持的数据集名称有:apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes。
执行如下命令,可以下载对应的数据集到 ```~/.cache/ppgan``` 并软连接到 ```PaddleGAN/data/``` 下。
```
python data/download_cyclegan_data.py --name horse2zebra
```
#### 使用自己的数据集
如果你使用自己的数据集,需要构造成如下目录的格式。注意 ```xxxA``````xxxB```文件数量,文件内容无需一一对应。
```
custom_datasets
├── testA
├── testB
├── trainA
└── trainB
```
### Pix2Pix相关的数据集下载
#### 已有的数据集下载
##### 从网页下载
pixel2pixel模型相关的数据集可以在[这里](hhttps://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)下载
##### 使用脚本下载
我们在 ```PaddleGAN/data``` 文件夹下提供了一个脚本 ```download_pix2pix_data.py``` 方便下载pix2pix模型相关的数据集。执行如下命令可以下载相关的数据集,目前支持的数据集名称有:apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes。
执行如下命令,可以下载对应的数据集到 ```~/.cache/ppgan``` 并软连接到 ```PaddleGAN/data/``` 下。
```
python data/download_pix2pix_data.py --name cityscapes
```
#### 使用自己的数据集
如果你使用自己的数据集,需要构造成如下目录的格式。同时图片应该制作成下图的样式,即左边为一种风格,另一边为相应转换的风格。
```
facades
├── test
├── train
└── val
```
![](./imgs/1.jpg)
## data prepare
It is recommended to symlink the dataset root to `$PaddleGAN/data`.
The config will suppose your data put in `$PaddleGAN/data`. You can symlink your datasets to `$PaddleGAN/data`.
```
PaddleGAN
......@@ -28,8 +28,65 @@ PaddleGAN
```
### cyclegan datasets
more dataset for cyclegan you can download from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)
if you put your datasets on other place,for example ```your/data/path```,
you can also change ```dataroot``` in config file:
### pix2pix datasets
more dataset for pix2pix you can download from [here](hhttps://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)
```
dataset:
train:
name: PairedDataset
dataroot: your/data/path
num_workers: 4
```
### Datasets of CycleGAN
#### download existed datasets
##### download form website
datasets for CycleGAN you can download from [here](https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/)
##### download by script
You can use ```download_cyclegan_data.py``` in ```PaddleGAN/data``` to download datasets you wanted. Supported datasets are: apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes。
run following command. Dataset will be downloaded to ```~/.cache/ppgan``` and symlink to ```PaddleGAN/data/``` .
```
python data/download_cyclegan_data.py --name horse2zebra
```
#### custom dataset
Data should be arranged in following way if you use custom dataset.
```
custom_datasets
├── testA
├── testB
├── trainA
└── trainB
```
### Datasets of Pix2Pix
#### download existed datasets
##### download from website
dataset for pix2pix you can download from [here](hhttps://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/)
##### download by script
You can use ```download_pix2pix_data.py``` in ```PaddleGAN/data``` to download datasets you wanted. Supported datasets are: apple2orange, summer2winter_yosemite,horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, cityscapes.
run following command. Dataset will be downloaded to ```~/.cache/ppgan``` and symlink to ```PaddleGAN/data/``` .
```
python data/download_pix2pix_data.py --name cityscapes
```
#### custom datasets
Data should be arranged in following way if you use custom dataset. And image content shoubld be same with example image.
```
facades
├── test
├── train
└── val
```
![](./imgs/1.jpg)
## 快速开始使用PaddleGAN
注意:
* 开始使用PaddleGAN前请确保已经阅读过[安装文档](./install.md),并根据[数据准备文档](./data_prepare.md)准备好数据集。
* 以下教程以CycleGAN模型在Cityscapes数据集上的训练预测作为示例。
### 训练
#### 单卡训练
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
#### 参数
continue train from last checkpoint
- `--config-file (str)`: 配置文件的路径。
输出的日志,权重,可视化结果会默认保存在```./output_dir```中,可以通过配置文件中的```output_dir```参数修改:
```
output_dir: output_dir
```
保存的文件夹会根据模型名字和时间戳自动生成一个新目录,目录示例如下:
```
output_dir
└── CycleGANModel-2020-10-29-09-21
├── epoch_1_checkpoint.pkl
├── log.txt
└── visual_train
├── epoch001_fake_A.png
├── epoch001_fake_B.png
├── epoch001_idt_A.png
├── epoch001_idt_B.png
├── epoch001_real_A.png
├── epoch001_real_B.png
├── epoch001_rec_A.png
├── epoch001_rec_B.png
├── epoch002_fake_A.png
├── epoch002_fake_B.png
├── epoch002_idt_A.png
├── epoch002_idt_B.png
├── epoch002_real_A.png
├── epoch002_real_B.png
├── epoch002_rec_A.png
└── epoch002_rec_B.png
```
#### 恢复训练
训练过程中默认会保存上一个epoch的checkpoint,方便恢复训练
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml --resume your_checkpoint_path
```
#### 参数
multiple gpus train:
- `--resume (str)`: 用来恢复训练的checkpoint路径。
#### 多卡训练:
```
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/pix2pix_cityscapes.yaml
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
### 预测
```
python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load your_weight_path
```
#### 参数
- `--evaluate-only`: 是否仅进行预测。
- `--load (str)`: 训练好的权重路径。
## Getting started with PaddleGAN
Note:
* Before starting to use PaddleGAN, please make sure you have read the [install document](./install_en.md), and prepare the dataset according to the [data preparation document](./data_prepare_en.md)
* The following tutorial uses the train and evaluate of the CycleGAN model on the Cityscapes dataset as an example
### Train
#### Train with single gpu
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
#### Args
- `--config-file (str)`: path of config file。
The output log, weight, and visualization result will be saved in ```./output_dir``` by default, which can be modified by the ```output_dir``` parameter in the config file:
```
output_dir: output_dir
```
continue train from last checkpoint
The saved folder will automatically generate a new directory based on the model name and timestamp. The directory example is as follows:
```
output_dir
└── CycleGANModel-2020-10-29-09-21
├── epoch_1_checkpoint.pkl
├── log.txt
└── visual_train
├── epoch001_fake_A.png
├── epoch001_fake_B.png
├── epoch001_idt_A.png
├── epoch001_idt_B.png
├── epoch001_real_A.png
├── epoch001_real_B.png
├── epoch001_rec_A.png
├── epoch001_rec_B.png
├── epoch002_fake_A.png
├── epoch002_fake_B.png
├── epoch002_idt_A.png
├── epoch002_idt_B.png
├── epoch002_real_A.png
├── epoch002_real_B.png
├── epoch002_rec_A.png
└── epoch002_rec_B.png
```
#### Recovery of training
The checkpoint of the previous epoch will be saved by default during the training process to facilitate the recovery of training
```
python -u tools/main.py --config-file configs/cyclegan_cityscapes.yaml --resume your_checkpoint_path
```
#### Args
multiple gpus train:
- `--resume (str)`: path of checkpoint。
#### Train with multiple gpus:
```
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/pix2pix_cityscapes.yaml
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch tools/main.py --config-file configs/cyclegan_cityscapes.yaml
```
### Evaluate
### evaluate
```
python tools/main.py --config-file configs/cyclegan_cityscapes.yaml --evaluate-only --load your_weight_path
```
#### Args
- `--evaluate-only`: whether to evaluate only。
- `--load (str)`: path of weight。
## 安装PaddleGAN
### 要求
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
* CUDA >= 9.0
### 1. 安装PaddlePaddle
```
pip install -U paddlepaddle-gpu==2.0.0rc0
```
上面命令会默认安装cuda10.2的包,如果想安装其他cuda版本的包,可以参考下面的表格。
<table class="docutils"><tbody><th width="80"> CUDA </th><th valign="bottom" align="left" width="100">python3.8</th><th valign="bottom" align="left" width="100">python3.7</th><th valign="bottom" align="left" width="100">python3.6</th> <tr><td align="left">10.1</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">10.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">9.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> </tr></tbody></table>
### 2. 安装ppgan
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
按照上述方法安装成功后,本地的修改也会自动同步到ppgan中
## Install PaddleGAN
### requirements
* PaddlePaddle >= 2.0.0-rc
* Python >= 3.5+
* CUDA >= 9.0
### 1. Install PaddlePaddle
```
pip install -U paddlepaddle-gpu==2.0.0rc0
```
Note: command above will install paddle with cuda10.2,if your installed cuda is different, you can choose an proper version to install from table below.
<table class="docutils"><tbody><th width="80"> CUDA </th><th valign="bottom" align="left" width="100">python3.8</th><th valign="bottom" align="left" width="100">python3.7</th><th valign="bottom" align="left" width="100">python3.6</th> <tr><td align="left">10.1</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10.1-cudnn7-mkl_gcc8.2%2Fpaddlepaddle_gpu-2.0.0rc0.post101-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">10.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda10-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post100-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"> </td> </tr> <tr><td align="left">9.0</td><td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp38-cp38-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp37-cp37m-linux_x86_64.whl
</code></pre> </details> </td> <td align="left"><details><summary> install </summary><pre><code>python -m pip install https://paddle-wheel.bj.bcebos.com/2.0.0-rc0-gpu-cuda9-cudnn7-mkl%2Fpaddlepaddle_gpu-2.0.0rc0.post90-cp36-cp36m-linux_x86_64.whl
</code></pre> </details> </td> </tr></tbody></table>
### 2. Install ppgan
```
git clone https://github.com/PaddlePaddle/PaddleGAN
cd PaddleGAN
pip install -v -e . # or "python setup.py develop"
```
......@@ -39,7 +39,7 @@ DAIN 模型通过探索深度的信息来显式检测遮挡。并且开发了一
```
ppgan.apps.DAINPredictor(
output_path='output',
output='output',
weight_path=None,
time_step=None,
use_gpu=True,
......@@ -47,7 +47,7 @@ ppgan.apps.DAINPredictor(
```
#### 参数
- `output_path (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `time_step (int)`: 补帧的时间系数,如果设置为0.5,则原先为每秒30帧的视频,补帧后变为每秒60帧。
- `remove_duplicates (bool,可选的)`: 是否删除重复帧,默认值:`False`.
......@@ -61,7 +61,7 @@ ppgan.apps.DeOldifyPredictor(output='output', weight_path=None, render_factor=32
```
#### 参数
- `output_path (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `render_factor (int)`: 会将该参数乘以16后作为输入帧的resize的值,如果该值设置为32,
则输入帧会resize到(32 * 16, 32 * 16)的尺寸再输入到网络中。
......@@ -80,7 +80,7 @@ ppgan.apps.DeepRemasterPredictor(
```
#### 参数
- `output_path (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
- `colorization (bool)`: 是否对输入视频上色,如果选项设置为 `True` ,则参考帧的文件夹路径也必须要设置。默认值:`False`
- `reference_dir (bool)`: 参考帧的文件夹路径。默认值:`None`
......@@ -96,7 +96,7 @@ ppgan.apps.RealSRPredictor(output='output', weight_path=None)
```
#### 参数
- `output_path (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
-
### 超分辨率模型EDVRPredictor
......@@ -111,5 +111,5 @@ ppgan.apps.EDVRPredictor(output='output', weight_path=None)
```
#### 参数
- `output_path (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `output (str,可选的)`: 输出的文件夹路径,默认值:`output`.
- `weight_path (None,可选的)`: 载入的权重路径,如果没有设置,则从云端下载默认的权重到本地。默认值:`None`
......@@ -22,7 +22,7 @@ from imageio import imread, imsave
import paddle
import paddle.fluid as fluid
from paddle.utils.download import get_path_from_url
from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import video2frames, frames2video
from .base_predictor import BasePredictor
......@@ -39,8 +39,7 @@ class DAINPredictor(BasePredictor):
remove_duplicates=False):
self.output_path = os.path.join(output, 'DAIN')
if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(DAIN_WEIGHT_URL, cur_path)
weight_path = get_path_from_url(DAIN_WEIGHT_URL)
self.weight_path = weight_path
self.time_step = time_step
......
......@@ -22,7 +22,7 @@ from skimage import color
import paddle
from ppgan.models.generators.remaster import NetworkR, NetworkC
from paddle.utils.download import get_path_from_url
from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor
DEEPREMASTER_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
......@@ -77,9 +77,7 @@ class DeepRemasterPredictor(BasePredictor):
self.mindim = mindim
if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL, cur_path)
print(weight_path)
weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL)
self.weight_path = weight_path
......
......@@ -20,7 +20,7 @@ from PIL import Image
from tqdm import tqdm
import paddle
from paddle.utils.download import get_path_from_url
from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators.deoldify import build_model
......@@ -36,8 +36,7 @@ class DeOldifyPredictor(BasePredictor):
self.render_factor = render_factor
self.model = build_model()
if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL, cur_path)
weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL)
state_dict = paddle.load(weight_path)
self.model.load_dict(state_dict)
......
......@@ -19,7 +19,7 @@ import glob
import numpy as np
from tqdm import tqdm
from paddle.utils.download import get_path_from_url
from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames
from .base_predictor import BasePredictor
......@@ -138,8 +138,7 @@ class EDVRPredictor(BasePredictor):
self.output = os.path.join(output, 'EDVR')
if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(EDVR_WEIGHT_URL, cur_path)
weight_path = get_path_from_url(EDVR_WEIGHT_URL)
self.weight_path = weight_path
......
......@@ -25,7 +25,7 @@ from skimage.transform import resize
from scipy.spatial import ConvexHull
import paddle
from paddle.utils.download import get_path_from_url
from ppgan.utils.download import get_path_from_url
from ppgan.utils.animate import normalize_kp
from ppgan.modules.keypoint_detector import KPDetector
from ppgan.models.generators.occlusion_aware import OcclusionAwareGenerator
......@@ -78,8 +78,7 @@ class FirstOrderPredictor(BasePredictor):
}
if weight_path is None:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(vox_cpk_weight_url, cur_path)
weight_path = get_path_from_url(vox_cpk_weight_url)
self.weight_path = weight_path
self.output = output
......
......@@ -22,7 +22,7 @@ from tqdm import tqdm
import paddle
from ppgan.models.generators import RRDBNet
from ppgan.utils.video import frames2video, video2frames
from paddle.utils.download import get_path_from_url
from ppgan.utils.download import get_path_from_url
from .base_predictor import BasePredictor
REALSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
......@@ -34,8 +34,7 @@ class RealSRPredictor(BasePredictor):
self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23)
if weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
weight_path = get_path_from_url(REALSR_WEIGHT_URL, cur_path)
weight_path = get_path_from_url(REALSR_WEIGHT_URL)
state_dict = paddle.load(weight_path)
self.model.load_dict(state_dict)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from tqdm import tqdm
import logging
from .logger import get_logger
logger = get_logger('ppgan')
PPGAN_HOME = os.path.expanduser("~/.cache/ppgan/")
DOWNLOAD_RETRY_LIMIT = 3
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path_from_url(url, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert is_url(url), "downloading from {} not a url".format(url)
root_dir = PPGAN_HOME
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().local_rank == 0:
if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
fullpath = _decompress(fullpath)
return fullpath
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].strip(os.sep).split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
file_name = file_list[0].split(os.sep)[0]
for i in range(1, len(file_list)):
if file_name != file_list[i].split(os.sep)[0]:
return False
return True
......@@ -18,6 +18,8 @@ import sys
from paddle.distributed import ParallelEnv
logger_initialized = {}
def setup_logger(output=None, name="ppgan"):
"""
......@@ -67,3 +69,11 @@ def setup_logger(output=None, name="ppgan"):
logger.addHandler(fh)
return logger
def get_logger(name, output=None):
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
return setup_logger(name=name, output=name)
FILE=$1
URL=https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./$FILE.tar.gz
TARGET_DIR=./$FILE/
wget -N $URL -O $TAR_FILE --no-check-certificate
mkdir $TARGET_DIR
tar -zxvf $TAR_FILE -C ../data/
rm $TAR_FILE
rm -rf $TARGET_DIR
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册