未验证 提交 581da7e9 编写于 作者: B BrilliantYuKaimin 提交者: GitHub

add SinGAN model (#576)

* Create singan.md

* Create gradient_penalty.py

* add GradientPenalty

* add images for singan document

* add SinGANModel

* add SinGANGenerator

* add SinGANDiscriminator

* Create discriminator_singan.py

* Create generator_singan.py

* Create singan_model.py

* Create empty_dataset.py

* add EmptyDataset

* Create singan_predictor.py

* add SinGANPredictor

* Create singan.py

* create configs for singan

* add inference for singan

* create tipc config for singan

* Create python_singan_results_fp32.txt

* add tipc prepare for singan

* Update test_train_inference_python.md

* Update readme.md

* Update singan.md

* Update singan_model.py

* Update prepare.sh

* Update train_infer_python.txt

* Update prepare.sh

* Revert "add images for singan document"

This reverts commit f45fe5e55a2588611d951ae84d776c90693788df.

* Update singan.md

* update path format in configs

* update year of copyright

* modify the order of import

* Update singan_predictor.py

* update configs for singan

* Update singan_model.py

* modify urls for singan in prepare.sh

* add pretrained weight for singan

* Update singan_model.py

* Update singan.md

* Create English tutorial for SinGAN
上级 71377845
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from ppgan.apps import SinGANPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--pretrained_model",
type=str,
default=None,
help="a pretianed model, only trees, stone, mountains, birds, and lightning are implemented.")
parser.add_argument("--mode",
type=str,
default="random_sample",
help="type of model for loading pretrained model")
parser.add_argument("--generate_start_scale",
type=int,
default=0,
help="sample random seed for model's image generation")
parser.add_argument("--seed",
type=int,
default=None,
help="sample random seed for model's image generation")
parser.add_argument("--scale_h",
type=float,
default=1.0,
help="horizontal scale")
parser.add_argument("--scale_v",
type=float,
default=1.0,
help="vertical scale")
parser.add_argument("--ref_image",
type=str,
default=None,
help="reference image for harmonization, editing and paint2image")
parser.add_argument("--mask_image",
type=str,
default=None,
help="mask image for harmonization and editing")
parser.add_argument("--sr_factor",
type=float,
default=4.0,
help="scale for super resolution")
parser.add_argument("--animation_alpha",
type=float,
default=0.9,
help="a parameter determines how close the frames of the sequence remain to the training image")
parser.add_argument("--animation_beta",
type=float,
default=0.9,
help="a parameter controls the smoothness and rate of change in the generated clip")
parser.add_argument("--animation_frames",
type=int,
default=20,
help="frame number of output animation when mode is animation")
parser.add_argument("--animation_duration",
type=float,
default=0.1,
help="duration of each frame in animation")
parser.add_argument("--n_row",
type=int,
default=5,
help="row number of output image grid")
parser.add_argument("--n_col",
type=int,
default=3,
help="column number of output image grid")
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = SinGANPredictor(args.output_path,
args.weight_path,
args.pretrained_model,
args.seed)
predictor.run(args.mode,
args.generate_start_scale,
args.scale_h,
args.scale_v,
args.ref_image,
args.mask_image,
args.sr_factor,
args.animation_alpha,
args.animation_beta,
args.animation_frames,
args.animation_duration,
args.n_row,
args.n_col)
total_iters: 100000
output_dir: output_dir
export_model: null
model:
name: SinGANModel
generator:
name: SinGANGenerator
nfc_init: 32
min_nfc_init: 32
noise_zero_pad: False
discriminator:
name: SinGANDiscriminator
nfc_init: 32
min_nfc_init: 32
gan_criterion:
name: GANLoss
gan_mode: wgangp
loss_weight: 1.0
recon_criterion:
name: MSELoss
loss_weight: 10.0
gp_criterion:
name: GradientPenalty
loss_weight: 0.1
train_image: data/singan/stone.png
scale_factor: 0.75
min_size: 25
is_finetune: False
dataset:
train:
name: EmptyDataset
test:
name: SingleDataset
dataroot: data/singan
num_workers: 0
batch_size: 1
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.0005
milestones: [9600]
gamma: 0.1
optimizer:
optimizer_G:
name: Adam
beta1: 0.5
beta2: 0.999
optimizer_D:
name: Adam
beta1: 0.5
beta2: 0.999
log_config:
interval: 100
visiual_interval: 2000
snapshot_config:
interval: 10000
validate:
interval: -1
save_img: True
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 1
total_iters: 12000
output_dir: output_dir
model:
name: SinGANModel
generator:
name: SinGANGenerator
nfc_init: 32
min_nfc_init: 32
noise_zero_pad: True
discriminator:
name: SinGANDiscriminator
nfc_init: 32
min_nfc_init: 32
gan_criterion:
name: GANLoss
gan_mode: wgangp
loss_weight: 1.0
recon_criterion:
name: MSELoss
loss_weight: 10.0
gp_criterion:
name: GradientPenalty
loss_weight: 0.1
train_image: data/singan/stone.png
scale_factor: 0.75
min_size: 25
is_finetune: True
finetune_scale: 1
color_num: 5
dataset:
train:
name: EmptyDataset
test:
name: SingleDataset
dataroot: data/singan
num_workers: 0
batch_size: 1
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.0005
milestones: [9600]
gamma: 0.1
optimizer:
optimizer_G:
name: Adam
beta1: 0.5
beta2: 0.999
optimizer_D:
name: Adam
beta1: 0.5
beta2: 0.999
log_config:
interval: 100
visiual_interval: 2000
snapshot_config:
interval: 4000
total_iters: 100000
output_dir: output_dir
export_model: null
model:
name: SinGANModel
generator:
name: SinGANGenerator
nfc_init: 32
min_nfc_init: 32
noise_zero_pad: True
discriminator:
name: SinGANDiscriminator
nfc_init: 32
min_nfc_init: 32
gan_criterion:
name: GANLoss
gan_mode: wgangp
loss_weight: 1.0
recon_criterion:
name: MSELoss
loss_weight: 100.0
gp_criterion:
name: GradientPenalty
loss_weight: 0.1
train_image: data/singan/stone.png
scale_factor: 0.793701 # (1/2)^(1/3)
min_size: 18
is_finetune: False
dataset:
train:
name: EmptyDataset
test:
name: SingleDataset
dataroot: data/singan
num_workers: 0
batch_size: 1
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.0005
milestones: [9600]
gamma: 0.1
optimizer:
optimizer_G:
name: Adam
beta1: 0.5
beta2: 0.999
optimizer_D:
name: Adam
beta1: 0.5
beta2: 0.999
log_config:
interval: 100
visiual_interval: 2000
snapshot_config:
interval: 10000
validate:
interval: -1
save_img: True
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 1
total_iters: 100000
output_dir: output_dir
export_model: null
model:
name: SinGANModel
generator:
name: SinGANGenerator
nfc_init: 32
min_nfc_init: 32
noise_zero_pad: True
discriminator:
name: SinGANDiscriminator
nfc_init: 32
min_nfc_init: 32
gan_criterion:
name: GANLoss
gan_mode: wgangp
loss_weight: 1.0
recon_criterion:
name: MSELoss
loss_weight: 10.0
gp_criterion:
name: GradientPenalty
loss_weight: 0.1
train_image: data/singan/stone.png
scale_factor: 0.75
min_size: 25
is_finetune: False
dataset:
train:
name: EmptyDataset
test:
name: SingleDataset
dataroot: data/singan
num_workers: 0
batch_size: 1
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
lr_scheduler:
name: MultiStepDecay
learning_rate: 0.0005
milestones: [9600]
gamma: 0.1
optimizer:
optimizer_G:
name: Adam
beta1: 0.5
beta2: 0.999
optimizer_D:
name: Adam
beta1: 0.5
beta2: 0.999
log_config:
interval: 100
visiual_interval: 2000
snapshot_config:
interval: 10000
validate:
interval: -1
save_img: True
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 1
# SinGAN
## Introduction
SinGAN is a novel unconditional* generative model that is trained using a single image. Traditionally, GANs have been trained on class-specific datasets and capture common features among images of the same class. SinGAN, on the other hand, learns from the overlapping patches at multiple scales of a particular image and learns its internal statistics. Once trained, SinGAN can produce assorted high-quality images of arbitrary sizes and aspect ratios that semantically resemble the training image but contain new object configurations and structures.
** An unconditional GAN creates samples purely from randomized input, while a conditional GAN generates samples based on a "class label" that controls the type of image generated.*
## Usage
### About Config Files
We provide 4 config files for SinGAN model:
- `singan_universal.yaml`
- `singan_sr.yaml`
- `singan_animation.yaml`
- `singan_finetune.yaml`
Among them, `singan_universal.yaml` is a config file suit for all tasks, `singan_sr.yaml` is a config file for super resolution recommended by the author, `singan_animation.yaml` is a config file for animation recommended by the author. Results showed in this document were trained with `singan_universal.yaml`. For *Paint to Image*, we will get better results by finetuning with `singan_finetune.yaml` after training with `singan_universal.yaml`.
### Train
Start training:
```bash
python tools/main.py -c configs/singan_universal.yaml \
-o model.train_image=train_image.png
```
Finetune for "Paint2Image":
```bash
python tools/main.py -c configs/singan_finetune.yaml \
-o model.train_image=train_image.png \
--load weight_saved_in_training.pdparams
```
### Evaluation
Running following command, a random image will be generated. It should be noted that `train_image.png` ought to be in directory `data/singan`, or you can modify the value of `dataset.test.dataroot` in config file manually. Besides, this directory must contain only one image, which is `train_image.png`.
```bash
python tools/main.py -c configs/singan_universal.yaml \
-o model.train_image=train_image.png \
--load weight_saved_in_training.pdparams \
--evaluate-only
```
### Extract Weight for Generator
After training, we need use ``tools/extract_weight.py`` to extract weight of generator from training model which includes both generator and discriminator. Then we can use `applications/tools/singan.py` to achieve diverse application of SinGAN.
```bash
python tools/extract_weight.py weight_saved_in_training.pdparams --net-name netG --output weight_of_generator.pdparams
```
### Inference and Result
*Attention: to use pretrained model, you can replace `--weight_path weight_of_generator.pdparams` in the following commands by `--pretrained_model <model>`, where `<model>` can be `trees`, `stone`, `mountains`, `birds` or `lightning`.*
#### Random Sample
```bash
python applications/tools/singan.py \
--weight_path weight_of_generator.pdparams \
--mode random_sample \
--scale_v 1 \ # vertical scale
--scale_h 1 \ # horizontal scale
--n_row 2 \
--n_col 2
```
|training image|result|
| ---- | ---- |
|![birds](https://user-images.githubusercontent.com/91609464/153211448-2614407b-a30b-467c-b1e5-7db88ff2ca74.png)|![birds-random_sample](https://user-images.githubusercontent.com/91609464/153211573-1af108ba-ad42-438a-94a9-e8f8f3e091eb.png)|
#### Editing & Harmonization
```bash
python applications/tools/singan.py \
--weight_path weight_of_generator.pdparams \
--mode editing \ # or harmonization
--ref_image editing_image.png \
--mask_image mask_of_editing.png \
--generate_start_scale 2
```
|training image|editing image|mask of editing|result|
|----|----|----|----|
|![stone](https://user-images.githubusercontent.com/91609464/153211778-bb94d29d-a2b4-4d04-9900-89b20ae90b90.png)|![stone-edit](https://user-images.githubusercontent.com/91609464/153211867-df3d9035-d320-45ec-8043-488e9da49bff.png)|![stone-edit-mask](https://user-images.githubusercontent.com/91609464/153212047-9620f73c-58d9-48ed-9af7-a11470ad49c8.png)|![stone-edit-mask-result](https://user-images.githubusercontent.com/91609464/153211942-e0e639c2-3ea6-4ade-852b-73757b0bbab0.png)|
#### Super Resolution
```bash
python applications/tools/singan.py \
--weight_path weight_of_generator.pdparams \
--mode sr \
--ref_image image_to_sr.png \
--sr_factor 4
```
|training image|result|
| ---- | ---- |
|![mountains](https://user-images.githubusercontent.com/91609464/153212146-efbbbbd6-e045-477a-87ae-10f121341060.png)|![sr](https://user-images.githubusercontent.com/91609464/153212176-530b7075-e72b-4c05-ad3e-2f2cdfc76dea.png)|
#### Animation
```bash
python applications/tools/singan.py \
--weight_path weight_of_generator.pdparams \
--mode animation \
--animation_alpha 0.6 \ # this parameter determines how close the frames of the sequence remain to the training image
--animation_beta 0.7 \ # this parameter controls the smoothness and rate of change in the generated clip
--animation_frames 20 \ # frames of animation
--animation_duration 0.1 # duration of each frame
```
|training image|animation|
| ---- | ---- |
|![lightning](https://user-images.githubusercontent.com/91609464/153212291-6f8976bd-e873-423e-ab62-77997df2df7a.png)|![animation](https://user-images.githubusercontent.com/91609464/153212372-0543e6d6-5842-472b-af50-8b22670270ae.gif)|
#### Paint to Image
```bash
python applications/tools/singan.py \
--weight_path weight_of_generator.pdparams \
--mode paint2image \
--ref_image paint.png \
--generate_start_scale 2
```
|training image|paint|result|result after finetune|
|----|----|----|----|
|![trees](https://user-images.githubusercontent.com/91609464/153212536-0bb6489d-d488-49e0-a6b5-90ef578c9e4f.png)|![trees-paint](https://user-images.githubusercontent.com/91609464/153212511-ef2c6bea-1f8c-4685-951b-8db589414dfe.png)|![trees-paint2image](https://user-images.githubusercontent.com/91609464/153212531-c080c705-fd58-4ade-aac6-e2134838a75f.png)|![trees-paint2image-finetuned](https://user-images.githubusercontent.com/91609464/153212529-51d8d29b-6b58-4f29-8792-4b2b04f9266e.png)|
## Reference
```
@misc{shaham2019singan,
title={SinGAN: Learning a Generative Model from a Single Natural Image},
author={Tamar Rott Shaham and Tali Dekel and Tomer Michaeli},
year={2019},
eprint={1905.01164},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# SinGAN
## 简介
SinGAN是一种新的可以从单个自然图像中学习的无条件生成模型。该模型包含一个全卷积生成对抗网络的金字塔结构,每个生成对抗网络负责学习不同在不同比例的图像上的块分布。这允许生成任意大小和纵横比的新样本,具有显著的可变性,但同时保持训练图像的全局结构和精细纹理。与以往单一图像生成方案相比,该方法不局限于纹理图像,也没有条件(即从噪声中生成样本)。
## 使用方法
### 配置说明
我们为SinGAN提供了4个配置文件:
- `singan_universal.yaml`
- `singan_sr.yaml`
- `singan_animation.yaml`
- `singan_finetune.yaml`
其中`singan_universal.yaml`对所有任务都适用配置,`singan_sr.yaml`是官方建议的用于超分任务的配置,`singan_animation.yaml`是官方建议的用于“静图转动”任务的配置。本文档展示的结果均由`singan_universal.yaml`训练而来。对于手绘转照片任务,使用`singan_universal.yaml`训练后再用`singan_finetune.yaml`微调会得到更好的结果。
### 训练
启动训练:
```bash
python tools/main.py -c configs/singan_universal.yaml \
-o model.train_image=训练图片.png
```
为“手绘转照片”任务微调:
```bash
python tools/main.py -c configs/singan_finetune.yaml \
-o model.train_image=训练图片.png \
--load 已经训练好的模型.pdparams
```
### 测试
运行下面的命令,可以随机生成一张图片。需要注意的是,`训练图片.png`应当位于`data/singan`目录下,或者手动调整配置文件中`dataset.test.dataroot`的值。此外,这个目录中只能包含`训练图片.png`这一张图片。
```bash
python tools/main.py -c configs/singan_universal.yaml \
-o model.train_image=训练图片.png \
--load 已经训练好的模型.pdparams \
--evaluate-only
```
### 导出生成器权重
训练结束后,需要使用 ``tools/extract_weight.py`` 来从训练模型(包含了生成器和判别器)中提取生成器的权重来给`applications/tools/singan.py`进行推理,以实现SinGAN的各种应用。
```bash
python tools/extract_weight.py 训练过程中保存的权重文件.pdparams --net-name netG --output 生成器权重文件.pdparams
```
### 推理及结果展示
*注意:您可以下面的命令中的`--weight_path 生成器权重文件.pdparams`可以换成`--pretrained_model <model> `来体验训练好的模型,其中`<model>`可以是`trees`、`stone`、`mountains`、`birds`和`lightning`。*
#### 随机采样
```bash
python applications/tools/singan.py \
--weight_path 生成器权重文件.pdparams \
--mode random_sample \
--scale_v 1 \ # vertical scale
--scale_h 1 \ # horizontal scale
--n_row 2 \
--n_col 2
```
|训练图片|随机采样结果|
| ---- | ---- |
|![birds](https://user-images.githubusercontent.com/91609464/153211448-2614407b-a30b-467c-b1e5-7db88ff2ca74.png)|![birds-random_sample](https://user-images.githubusercontent.com/91609464/153211573-1af108ba-ad42-438a-94a9-e8f8f3e091eb.png)|
#### 图像编辑&风格和谐化
```bash
python applications/tools/singan.py \
--weight_path 生成器权重文件.pdparams \
--mode editing \ # or harmonization
--ref_image 编辑后的图片.png \
--mask_image 编辑区域标注图片.png \
--generate_start_scale 2
```
|训练图片|编辑图片|编辑区域标注|SinGAN生成|
|----|----|----|----|
|![stone](https://user-images.githubusercontent.com/91609464/153211778-bb94d29d-a2b4-4d04-9900-89b20ae90b90.png)|![stone-edit](https://user-images.githubusercontent.com/91609464/153211867-df3d9035-d320-45ec-8043-488e9da49bff.png)|![stone-edit-mask](https://user-images.githubusercontent.com/91609464/153212047-9620f73c-58d9-48ed-9af7-a11470ad49c8.png)|![stone-edit-mask-result](https://user-images.githubusercontent.com/91609464/153211942-e0e639c2-3ea6-4ade-852b-73757b0bbab0.png)|
#### 超分
```bash
python applications/tools/singan.py \
--weight_path 生成器权重文件.pdparams \
--mode sr \
--ref_image 待超分的图片亦即用于训练的图片.png \
--sr_factor 4
```
|训练图片|超分结果|
| ---- | ---- |
|![mountains](https://user-images.githubusercontent.com/91609464/153212146-efbbbbd6-e045-477a-87ae-10f121341060.png)|![sr](https://user-images.githubusercontent.com/91609464/153212176-530b7075-e72b-4c05-ad3e-2f2cdfc76dea.png)|
#### 静图转动
```bash
python applications/tools/singan.py \
--weight_path 生成器权重文件.pdparams \
--mode animation \
--animation_alpha 0.6 \ # this parameter determines how close the frames of the sequence remain to the training image
--animation_beta 0.7 \ # this parameter controls the smoothness and rate of change in the generated clip
--animation_frames 20 \ # frames of animation
--animation_duration 0.1 # duration of each frame
```
|训练图片|动画效果|
| ---- | ---- |
|![lightning](https://user-images.githubusercontent.com/91609464/153212291-6f8976bd-e873-423e-ab62-77997df2df7a.png)|![animation](https://user-images.githubusercontent.com/91609464/153212372-0543e6d6-5842-472b-af50-8b22670270ae.gif)|
#### 手绘转照片
```bash
python applications/tools/singan.py \
--weight_path 生成器权重文件.pdparams \
--mode paint2image \
--ref_image 手绘图片.png \
--generate_start_scale 2
```
|训练图片|手绘图片|SinGAN生成|SinGAN微调后生成|
|----|----|----|----|
|![trees](https://user-images.githubusercontent.com/91609464/153212536-0bb6489d-d488-49e0-a6b5-90ef578c9e4f.png)|![trees-paint](https://user-images.githubusercontent.com/91609464/153212511-ef2c6bea-1f8c-4685-951b-8db589414dfe.png)|![trees-paint2image](https://user-images.githubusercontent.com/91609464/153212531-c080c705-fd58-4ade-aac6-e2134838a75f.png)|![trees-paint2image-finetuned](https://user-images.githubusercontent.com/91609464/153212529-51d8d29b-6b58-4f29-8792-4b2b04f9266e.png)|
## 参考文献
```
@misc{shaham2019singan,
title={SinGAN: Learning a Generative Model from a Single Natural Image},
author={Tamar Rott Shaham and Tali Dekel and Tomer Michaeli},
year={2019},
eprint={1905.01164},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
......@@ -34,3 +34,4 @@ from .photopen_predictor import PhotoPenPredictor
from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
BasiVSRPlusPlusPredictor, IconVSRPredictor, \
PPMSVSRLargePredictor)
from .singan_predictor import SinGANPredictor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import cv2
import math
import skimage
import imageio
import paddle
import paddle.nn.functional as F
import paddle.vision.transforms as T
from .base_predictor import BasePredictor
from ..models.singan_model import pad_shape
from ppgan.models.generators import SinGANGenerator
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import tensor2img, save_image, make_grid
pretrained_weights_url = {
'trees': 'https://paddlegan.bj.bcebos.com/models/singan_universal_trees.pdparams',
'stone': 'https://paddlegan.bj.bcebos.com/models/singan_universal_stone.pdparams',
'mountains': 'https://paddlegan.bj.bcebos.com/models/singan_universal_mountains.pdparams',
'birds': 'https://paddlegan.bj.bcebos.com/models/singan_universal_birds.pdparams',
'lightning': 'https://paddlegan.bj.bcebos.com/models/singan_universal_lightning.pdparams'
}
def imread(path):
return cv2.cvtColor(
cv2.imread(
path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
def imgpath2tensor(path):
return paddle.to_tensor(T.Compose([
T.Transpose(),
T.Normalize(127.5, 127.5)
])(imread(path))).unsqueeze(0)
def dilate_mask(mask, mode):
if mode == "harmonization":
element = skimage.morphology.disk(radius=7)
elif mode == "editing":
element = skimage.morphology.disk(radius=20)
else:
raise NotImplementedError('mode %s is not implemented' % mode)
mask = skimage.morphology.binary_dilation(mask, selem=element)
mask = skimage.filters.gaussian(mask, sigma=5)
return mask
class SinGANPredictor(BasePredictor):
def __init__(self,
output_path='output_dir',
weight_path=None,
pretrained_model=None,
seed=None):
self.output_path = output_path
if weight_path is None:
if pretrained_model in pretrained_weights_url.keys():
weight_path = get_path_from_url(
pretrained_weights_url[pretrained_model])
else:
raise ValueError(
'Predictor need a weight path or a pretrained model.')
checkpoint = paddle.load(weight_path)
self.scale_num = checkpoint['scale_num'].item()
self.coarsest_shape = checkpoint['coarsest_shape'].tolist()
self.nfc_init = checkpoint['nfc_init'].item()
self.min_nfc_init = checkpoint['min_nfc_init'].item()
self.num_layers = checkpoint['num_layers'].item()
self.ker_size = checkpoint['ker_size'].item()
self.noise_zero_pad = checkpoint['noise_zero_pad'].item()
self.generator = SinGANGenerator(self.scale_num,
self.coarsest_shape,
self.nfc_init,
self.min_nfc_init,
3,
self.num_layers,
self.ker_size,
self.noise_zero_pad)
self.generator.set_state_dict(checkpoint)
self.generator.eval()
self.scale_factor = self.generator.scale_factor.item()
self.niose_pad_size = 0 if self.noise_zero_pad \
else self.generator._pad_size
if seed is not None:
paddle.seed(seed)
def noise_like(self, x):
return paddle.randn(pad_shape(x.shape, self.niose_pad_size))
def run(self,
mode='random_sample',
generate_start_scale=0,
scale_h=1.0,
scale_v=1.0,
ref_image=None,
mask_image=None,
sr_factor=4,
animation_alpha=0.9,
animation_beta=0.9,
animation_frames=20,
animation_duration=0.1,
n_row=5,
n_col=3):
# check config
if mode not in ['random_sample',
'sr', 'animation',
'harmonization',
'editing', 'paint2image']:
raise ValueError(
'Only random_sample, sr, animation, harmonization, \
editing and paint2image is implemented.')
if mode in ['sr', 'harmonization', 'editing', 'paint2image'] and \
ref_image is None:
raise ValueError(
'When mode is sr, harmonization, editing, or \
paint2image, a reference image must be privided.')
if mode in ['harmonization', 'editing'] and mask_image is None:
raise ValueError(
'When mode is harmonization or editing, \
a mask image must be privided.')
if mode == 'animation':
batch_size = animation_frames
elif mode == 'random_sample':
batch_size = n_row * n_col
else:
batch_size = 1
# prepare input
if mode == 'harmonization' or mode == 'editing' or mode == 'paint2image':
ref = imgpath2tensor(ref_image)
x_init = F.interpolate(
ref, None,
self.scale_factor ** (self.scale_num - generate_start_scale),
'bicubic')
x_init = F.interpolate(
x_init, None, 1 / self.scale_factor, 'bicubic')
elif mode == 'sr':
ref = imgpath2tensor(ref_image)
sr_iters = math.ceil(math.log(sr_factor, 1 / self.scale_factor))
sr_scale_factor = sr_factor ** (1 / sr_factor)
x_init = F.interpolate(ref, None, sr_scale_factor, 'bicubic')
else:
x_init = paddle.zeros([
batch_size,
self.coarsest_shape[1],
int(self.coarsest_shape[2] * scale_v),
int(self.coarsest_shape[3] * scale_h)])
# forward
if mode == 'sr':
for _ in range(sr_iters):
out = self.generator([self.noise_like(x_init)], x_init, -1, -1)
x_init = F.interpolate(out, None, sr_scale_factor, 'bicubic')
else:
z_pyramid = [
self.noise_like(
F.interpolate(
x_init, None, 1 / self.scale_factor ** i))
for i in range(self.scale_num - generate_start_scale)]
if mode == 'animation':
a = animation_alpha
b = animation_beta
for i in range(len(z_pyramid)):
z = paddle.chunk(z_pyramid[i], batch_size)
if i == 0 and generate_start_scale == 0:
z_0 = F.interpolate(
self.generator.z_fixed,
pad_shape(x_init.shape[-2:], self.niose_pad_size),
None, 'bicubic')
else:
z_0 = 0
z_1 = z_0
z_2 = 0.95 * z_1 + 0.05 * z[0]
for j in range(len(z)):
z[j] = a * z_0 + (1 - a) * (z_2 + b * (z_2 - z_1) + (1 - b) * z[j])
z_1 = z_2
z_2 = z[j]
z = paddle.concat(z)
z_pyramid[i] = z
out = self.generator(z_pyramid, x_init, self.scale_num - 1, generate_start_scale)
# postprocess and save
os.makedirs(self.output_path, exist_ok=True)
if mode == 'animation':
frames = [tensor2img(x) for x in out.chunk(animation_frames)]
imageio.mimsave(
os.path.join(self.output_path, 'animation.gif'),
frames, 'GIF', duration=animation_duration)
else:
if mode == 'harmonization' or mode == 'editing':
mask = cv2.imread(mask_image, cv2.IMREAD_GRAYSCALE)
mask = paddle.to_tensor(dilate_mask(mask, mode), 'float32')
out = F.interpolate(out, mask.shape, None, 'bicubic')
out = (1 - mask) * ref + mask * out
elif mode == 'sr':
out = F.interpolate(
out,
[ref.shape[-2] * sr_factor, ref.shape[-1] * sr_factor],
None, 'bicubic')
elif mode == 'paint2image':
out = F.interpolate(out, ref.shape[-2:], None, 'bicubic')
elif mode == 'random_sample':
out = make_grid(out, n_row)
save_image(tensor2img(out), os.path.join(self.output_path, mode + '.png'))
......@@ -29,3 +29,4 @@ from .vsr_reds_multiple_gt_dataset import VSRREDSMultipleGTDataset
from .vsr_vimeo90k_dataset import VSRVimeo90KDataset
from .vsr_folder_dataset import VSRFolderDataset
from .photopen_dataset import PhotoPenDataset
from .empty_dataset import EmptyDataset
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_dataset import BaseDataset
from .builder import DATASETS
@DATASETS.register()
class EmptyDataset(BaseDataset):
'''
Dataset for models who don't need a dataset.
'''
def __init__(self, size=1):
super().__init__()
self.size = size
self.data_infos = self.prepare_data_infos()
def prepare_data_infos(self):
return [{i: 0} for i in range(self.size)]
......@@ -34,3 +34,4 @@ from .basicvsr_model import BasicVSRModel
from .mpr_model import MPRModel
from .photopen_model import PhotoPenModel
from .msvsr_model import MultiStageVSRModel
from .singan_model import SinGANModel
......@@ -4,5 +4,6 @@ from .pixel_loss import L1Loss, MSELoss, CharbonnierLoss, \
CalcStyleEmdLoss, CalcContentReltLoss, \
CalcContentLoss, CalcStyleLoss, EdgeLoss
from .photopen_perceptual_loss import PhotoPenPerceptualLoss
from .gradient_penalty import GradientPenalty
from .builder import build_criterion
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .builder import CRITERIONS
@CRITERIONS.register()
class GradientPenalty():
def __init__(self, loss_weight=1.0):
self.loss_weight = loss_weight
def __call__(self, net, real, fake):
batch_size = real.shape[0]
alpha = paddle.rand([batch_size])
for _ in range(real.ndim - 1):
alpha = paddle.unsqueeze(alpha, -1)
interpolate = alpha * real + (1 - alpha) * fake
interpolate.stop_gradient = False
interpolate_pred = net(interpolate)
gradient = paddle.grad(outputs=interpolate_pred,
inputs=interpolate,
grad_outputs=paddle.ones_like(interpolate_pred),
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
gradient_penalty = ((gradient.norm(2, 1) - 1) ** 2).mean()
return gradient_penalty * self.loss_weight
......@@ -24,3 +24,4 @@ from .discriminator_starganv2 import StarGANv2Discriminator
from .discriminator_firstorder import FirstOrderDiscriminator
from .discriminator_lapstyle import LapStyleDiscriminator
from .discriminator_photopen import MultiscaleDiscriminator
from .discriminator_singan import SinGANDiscriminator
# code was based on https://github.com/tamarott/SinGAN
import paddle.nn as nn
from ..generators.generator_singan import ConvBlock
from .builder import DISCRIMINATORS
@DISCRIMINATORS.register()
class SinGANDiscriminator(nn.Layer):
def __init__(self,
nfc=32,
min_nfc=32,
input_nc=3,
num_layers=5,
ker_size=3,
padd_size=0):
super(SinGANDiscriminator, self).__init__()
self.head = ConvBlock(input_nc, nfc, ker_size, padd_size, 1)
self.body = nn.Sequential()
for i in range(num_layers - 2):
N = int(nfc / pow(2, (i + 1)))
block = ConvBlock(max(2 * N, min_nfc), max(N, min_nfc), ker_size, padd_size, 1)
self.body.add_sublayer('block%d' % (i + 1), block)
self.tail = nn.Conv2D(max(N, min_nfc), 1, ker_size, 1, padd_size)
def forward(self, x):
x = self.head(x)
x = self.body(x)
x = self.tail(x)
return x
......@@ -38,3 +38,4 @@ from .pan import PAN
from .generater_photopen import SPADEGenerator
from .basicvsr_plus_plus import BasicVSRPlusPlus
from .msvsr import MSVSR
from .generator_singan import SinGANGenerator
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# code was based on https://github.com/tamarott/SinGAN
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .builder import GENERATORS
class ConvBlock(nn.Sequential):
def __init__(self, in_channel, out_channel, ker_size, padd, stride):
super(ConvBlock,self).__init__()
self.add_sublayer('conv', nn.Conv2D(in_channel ,out_channel, ker_size, stride, padd)),
self.add_sublayer('norm', nn.BatchNorm2D(out_channel)),
self.add_sublayer('LeakyRelu', nn.LeakyReLU(0.2))
class GeneratorConcatSkip2CleanAdd(nn.Layer):
def __init__(self, nfc=32, min_nfc=32, input_nc=3, num_layers=5, ker_size=3, padd_size=0):
super(GeneratorConcatSkip2CleanAdd, self).__init__()
self.head = ConvBlock(input_nc, nfc, ker_size, padd_size, 1)
self.body = nn.Sequential()
for i in range(num_layers - 2):
N = int(nfc / pow(2, i + 1))
block = ConvBlock(max(2 * N, min_nfc), max(N, min_nfc), ker_size, padd_size, 1)
self.body.add_sublayer('block%d' % (i + 1), block)
self.tail = nn.Sequential(
nn.Conv2D(max(N, min_nfc), input_nc, ker_size, 1, padd_size),
nn.Tanh())
def forward(self, x, y):
x = self.head(x)
x = self.body(x)
x = self.tail(x)
ind = int((y.shape[2] - x.shape[2]) / 2)
y = y[:, :, ind: (y.shape[2] - ind), ind: (y.shape[3] - ind)]
return x + y
@GENERATORS.register()
class SinGANGenerator(nn.Layer):
def __init__(self,
scale_num,
coarsest_shape,
nfc_init=32,
min_nfc_init=32,
input_nc=3,
num_layers=5,
ker_size=3,
noise_zero_pad=True):
super().__init__()
nfc_list = [min(nfc_init * pow(2, math.floor(i / 4)), 128) for i in range(scale_num)]
min_nfc_list = [min(min_nfc_init * pow(2, math.floor(i / 4)), 128) for i in range(scale_num)]
self.generators = nn.LayerList([
GeneratorConcatSkip2CleanAdd(
nfc, min_nfc, input_nc, num_layers,
ker_size, 0
) for nfc, min_nfc in zip(nfc_list, min_nfc_list)])
self._scale_num = scale_num
self._pad_size = int((ker_size - 1) / 2 * num_layers)
self.noise_pad = nn.Pad2D(self._pad_size if noise_zero_pad else 0)
self.image_pad = nn.Pad2D(self._pad_size)
self._noise_zero_pad = noise_zero_pad
self._coarsest_shape = coarsest_shape
self.register_buffer('scale_num', paddle.to_tensor(scale_num, 'int32'), True)
self.register_buffer('coarsest_shape', paddle.to_tensor(coarsest_shape, 'int32'), True)
self.register_buffer('nfc_init', paddle.to_tensor(nfc_init, 'int32'), True)
self.register_buffer('min_nfc_init', paddle.to_tensor(min_nfc_init, 'int32'), True)
self.register_buffer('num_layers', paddle.to_tensor(num_layers, 'int32'), True)
self.register_buffer('ker_size', paddle.to_tensor(ker_size, 'int32'), True)
self.register_buffer('noise_zero_pad', paddle.to_tensor(noise_zero_pad, 'bool'), True)
self.register_buffer('sigma', paddle.ones([scale_num]), True)
self.register_buffer('scale_factor', paddle.ones([1]), True)
self.register_buffer(
'z_fixed',
paddle.randn(
F.pad(
paddle.zeros(coarsest_shape),
[0 if noise_zero_pad else self._pad_size] * 4).shape), True)
def forward(self, z_pyramid, x_prev, stop_scale, start_scale=0):
stop_scale %= self._scale_num
start_scale %= self._scale_num
for i, scale in enumerate(range(start_scale, stop_scale + 1)):
x_prev = self.image_pad(x_prev)
z = self.noise_pad(z_pyramid[i] * self.sigma[scale]) + x_prev
x_prev = self.generators[scale](
z.detach(),
x_prev.detach()
)
if scale < stop_scale:
x_prev = F.interpolate(x_prev,
F.pad(z_pyramid[i + 1], [0 if self._noise_zero_pad else -self._pad_size] * 4).shape[-2:],
None, 'bicubic')
return x_prev
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import math
import warnings
from collections import OrderedDict
from sklearn.cluster import KMeans
import paddle
import paddle.nn.functional as F
import paddle.vision.transforms as T
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from .discriminators.builder import build_discriminator
from ..solver import build_lr_scheduler, build_optimizer
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
def pad_shape(shape, pad_size):
shape[-2] += 2 * pad_size
shape[-1] += 2 * pad_size
return shape
def quant(x, num):
n, c, h, w = x.shape
kmeans = KMeans(num, random_state=0).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
centers = kmeans.cluster_centers_
x = centers[kmeans.labels_].reshape([n, h, w, c]).transpose([0, 3, 1, 2])
return paddle.to_tensor(x, 'float32'), centers
def quant_to_centers(x, centers):
n, c, h, w = x.shape
num = centers.shape[0]
kmeans = KMeans(num, init=centers, n_init=1).fit(x.transpose([0, 2, 3, 1]).reshape([-1, c]))
x = centers[kmeans.labels_].reshape([n, h, w, c]).transpose([0, 3, 1, 2])
return paddle.to_tensor(x, 'float32')
@MODELS.register()
class SinGANModel(BaseModel):
def __init__(self,
generator,
discriminator,
gan_criterion=None,
recon_criterion=None,
gp_criterion=None,
train_image=None,
scale_factor=0.75,
min_size=25,
is_finetune=False,
finetune_scale=1,
color_num=5,
gen_iters=3,
disc_iters=3,
noise_amp_init=0.1):
super(SinGANModel, self).__init__()
# setup config
self.gen_iters = gen_iters
self.disc_iters = disc_iters
self.min_size = min_size
self.is_finetune = is_finetune
self.noise_amp_init = noise_amp_init
self.train_image = T.Compose([
T.Transpose(),
T.Normalize(127.5, 127.5)
])(cv2.cvtColor(cv2.imread(train_image, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB))
self.train_image = paddle.to_tensor(self.train_image).unsqueeze(0)
self.scale_num = math.ceil(math.log(
self.min_size / min(self.train_image.shape[-2:]),
scale_factor)) + 1
self.scale_factor = math.pow(
self.min_size / min(self.train_image.shape[-2:]),
1 / (self.scale_num - 1))
self.reals = [
F.interpolate(self.train_image, None, self.scale_factor ** i, 'bicubic')
for i in range(self.scale_num - 1, -1, -1)]
# build generator
generator['scale_num'] = self.scale_num
generator['coarsest_shape'] =self.reals[0].shape
self.nets['netG'] = build_generator(generator)
self.niose_pad_size = 0 if generator.get('noise_zero_pad', True) \
else self.nets['netG']._pad_size
self.nets['netG'].scale_factor = paddle.to_tensor(self.scale_factor, 'float32')
# build discriminator
nfc_init = discriminator.pop('nfc_init', 32)
min_nfc_init = discriminator.pop('min_nfc_init', 32)
for i in range(self.scale_num):
discriminator['nfc'] = min(nfc_init * pow(2, math.floor(i / 4)), 128)
discriminator['min_nfc'] = min(min_nfc_init * pow(2, math.floor(i / 4)), 128)
self.nets[f'netD{i}'] = build_discriminator(discriminator)
# build criterion
self.gan_criterion = build_criterion(gan_criterion)
self.recon_criterion = build_criterion(recon_criterion)
self.gp_criterion = build_criterion(gp_criterion)
if self.is_finetune:
self.finetune_scale = finetune_scale
self.quant_real, self.quant_centers = quant(self.reals[finetune_scale], color_num)
# setup training config
self.lr_schedulers = OrderedDict()
self.current_scale = (finetune_scale if self.is_finetune else 0) - 1
self.current_iter = 0
def set_total_iter(self, total_iter):
super().set_total_iter(total_iter)
if self.is_finetune:
self.scale_iters = total_iter
else:
self.scale_iters = math.ceil(total_iter / self.scale_num)
def setup_lr_schedulers(self, cfg):
for i in range(self.scale_num):
self.lr_schedulers[f"lr{i}"] = build_lr_scheduler(cfg)
return self.lr_schedulers
def setup_optimizers(self, lr_schedulers, cfg):
for i in range(self.scale_num):
self.optimizers[f'optim_netG{i}'] = build_optimizer(
cfg['optimizer_G'], lr_schedulers[f"lr{i}"], self.nets[f'netG'].generators[i].parameters())
self.optimizers[f'optim_netD{i}'] = build_optimizer(
cfg['optimizer_D'], lr_schedulers[f"lr{i}"], self.nets[f'netD{i}'].parameters())
return self.optimizers
def setup_input(self, input):
pass
def backward_D(self):
self.loss_D_real = self.gan_criterion(self.pred_real, True, True)
self.loss_D_fake = self.gan_criterion(self.pred_fake, False, True)
self.loss_D_gp = self.gp_criterion(self.nets[f'netD{self.current_scale}'],
self.real_img,
self.fake_img)
self.loss_D = self.loss_D_real + self.loss_D_fake + self.loss_D_gp
self.loss_D.backward()
self.losses[f'scale{self.current_scale}/D_total_loss'] = self.loss_D
self.losses[f'scale{self.current_scale}/D_real_loss'] = self.loss_D_real
self.losses[f'scale{self.current_scale}/D_fake_loss'] = self.loss_D_fake
self.losses[f'scale{self.current_scale}/D_gradient_penalty'] = self.loss_D_gp
def backward_G(self):
self.loss_G_gan = self.gan_criterion(self.pred_fake, True, False)
self.loss_G_recon = self.recon_criterion(self.recon_img, self.real_img)
self.loss_G = self.loss_G_gan + self.loss_G_recon
self.loss_G.backward()
self.losses[f'scale{self.current_scale}/G_adv_loss'] = self.loss_G_gan
self.losses[f'scale{self.current_scale}/G_recon_loss'] = self.loss_G_recon
def scale_prepare(self):
self.real_img = self.reals[self.current_scale]
self.lr_scheduler = self.lr_schedulers[f"lr{self.current_scale}"]
for i in range(self.current_scale):
self.optimizers.pop(f'optim_netG{i}', None)
self.optimizers.pop(f'optim_netD{i}', None)
self.losses.clear()
self.visual_items.clear()
self.visual_items[f'real_img_scale{self.current_scale}'] = self.real_img
if self.is_finetune:
self.visual_items['quant_real'] = self.quant_real
self.recon_prev = paddle.zeros_like(self.reals[0])
if self.current_scale > 0:
z_pyramid = []
for i in range(self.current_scale):
if i == 0:
z = self.nets['netG'].z_fixed
else:
z = paddle.zeros(
pad_shape(
self.reals[i].shape, self.niose_pad_size))
z_pyramid.append(z)
self.recon_prev = self.nets['netG'](
z_pyramid, self.recon_prev,
self.current_scale - 1, 0).detach()
self.recon_prev = F.interpolate(
self.recon_prev, self.real_img.shape[-2:], None, 'bicubic')
if self.is_finetune:
self.recon_prev = quant_to_centers(self.recon_prev, self.quant_centers)
self.nets['netG'].sigma[self.current_scale] = F.mse_loss(
self.real_img, self.recon_prev
).sqrt() * self.noise_amp_init
for i in range(self.scale_num):
self.set_requires_grad(self.nets['netG'].generators[i], i == self.current_scale)
def forward(self):
if not self.is_finetune:
self.fake_img = self.nets['netG'](
self.z_pyramid,
paddle.zeros(
pad_shape(self.z_pyramid[0].shape, -self.niose_pad_size)),
self.current_scale, 0)
else:
x_prev = self.nets['netG'](
self.z_pyramid[:self.finetune_scale],
paddle.zeros(
pad_shape(self.z_pyramid[0].shape, -self.niose_pad_size)),
self.finetune_scale - 1, 0)
x_prev = F.interpolate(x_prev, self.z_pyramid[self.finetune_scale].shape[-2:], None, 'bicubic')
x_prev_quant = quant_to_centers(x_prev, self.quant_centers)
self.fake_img = self.nets['netG'](
self.z_pyramid[self.finetune_scale:],
x_prev_quant,
self.current_scale, self.finetune_scale)
self.recon_img = self.nets['netG'](
[(paddle.randn if self.current_scale == 0 else paddle.zeros)(
pad_shape(self.real_img.shape, self.niose_pad_size))],
self.recon_prev,
self.current_scale,
self.current_scale)
self.pred_real = self.nets[f'netD{self.current_scale}'](self.real_img)
self.pred_fake = self.nets[f'netD{self.current_scale}'](
self.fake_img.detach() if self.update_D else self.fake_img)
self.visual_items[f'fake_img_scale{self.current_scale}'] = self.fake_img
self.visual_items[f'recon_img_scale{self.current_scale}'] = self.recon_img
if self.is_finetune:
self.visual_items[f'prev_img_scale{self.current_scale}'] = x_prev
self.visual_items[f'quant_prev_img_scale{self.current_scale}'] = x_prev_quant
def train_iter(self, optimizers=None):
if self.current_iter % self.scale_iters == 0:
self.current_scale += 1
self.scale_prepare()
self.z_pyramid = [paddle.randn(
pad_shape(self.reals[i].shape, self.niose_pad_size))
for i in range(self.current_scale + 1)]
self.update_D = (self.current_iter % (self.disc_iters + self.gen_iters) < self.disc_iters)
self.set_requires_grad(self.nets[f'netD{self.current_scale}'], self.update_D)
self.forward()
if self.update_D:
optimizers[f'optim_netD{self.current_scale}'].clear_grad()
self.backward_D()
optimizers[f'optim_netD{self.current_scale}'].step()
else:
optimizers[f'optim_netG{self.current_scale}'].clear_grad()
self.backward_G()
optimizers[f'optim_netG{self.current_scale}'].step()
self.current_iter += 1
def test_iter(self, metrics=None):
z_pyramid = [paddle.randn(
pad_shape(self.reals[i].shape, self.niose_pad_size))
for i in range(self.scale_num)]
self.nets['netG'].eval()
fake_img = self.nets['netG'](
z_pyramid,
paddle.zeros(pad_shape(z_pyramid[0].shape, -self.niose_pad_size)),
self.scale_num - 1, 0)
self.visual_items['fake_img_test'] = fake_img
with paddle.no_grad():
if metrics is not None:
for metric in metrics.values():
metric.update(fake_img, self.train_image)
self.nets['netG'].train()
class InferGenerator(paddle.nn.Layer):
def set_config(self, generator, noise_shapes, scale_num):
self.generator = generator
self.noise_shapes = noise_shapes
self.scale_num = scale_num
def forward(self, x):
coarsest_shape = self.generator._coarsest_shape
z_pyramid = [paddle.randn(shp) for shp in self.noise_shapes]
x_init = paddle.zeros(coarsest_shape)
out = self.generator(z_pyramid, x_init, self.scale_num - 1, 0)
return out
def export_model(self,
export_model=None,
output_dir=None,
inputs_size=None,
export_serving_model=False):
noise_shapes = [pad_shape(x.shape, self.niose_pad_size) for x in self.reals]
infer_generator = self.InferGenerator()
infer_generator.set_config(self.nets['netG'], noise_shapes, self.scale_num)
paddle.jit.save(infer_generator,
os.path.join(output_dir, "singan_random_sample"),
input_spec=[1])
===========================train_params===========================
model_name:singan
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=100|whole_train_whole_infer=100000
output_dir:./output/
snapshot_config.interval:lite_train_lite_infer=25|whole_train_whole_infer=10000
pretrained_model:null
train_model_name:singan*/*checkpoint.pdparams
train_infer_img_dir:./data/stone
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/singan_universal.yaml --seed 123 -o log_config.interval=50
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/singan_universal.yaml --inputs_size=1 --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:singan_random_sample
train_model:./inference/singan/singan_random_sample
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type singan --seed 123 -c configs/singan_universal.yaml --output_path test_tipc/output/
--device:cpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
\ No newline at end of file
......@@ -15,6 +15,7 @@ Linux端基础训练预测功能测试的主程序为`test_train_inference_pytho
| FOMM |FOMM | 生成 | 支持 | 多机多卡 | | |
| BasicVSR |BasicVSR | 超分 | 支持 | 多机多卡 | | |
|PP-MSVSR|PP-MSVSR | 超分|
|SinGAN|SinGAN | 生成|支持|
- 预测相关:预测功能汇总如下,
......
......@@ -66,6 +66,12 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm -rf ./data/DIV2K*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14.tar --no-check-certificate
cd ./data/ && tar xf DIV2KandSet14.tar && cd ../ ;;
singan)
rm -rf ./data/SinGAN*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
cd ./data/ && unzip -q singan-official_images.zip && cd ../ ;;
mkdir -p ./data/singan
mv ./data/SinGAN-official_images/Images/stone.png ./data/singan
esac
elif [ ${MODE} = "whole_train_whole_infer" ];then
if [ ${model_name} == "pix2pix" ]; then
......@@ -76,6 +82,12 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
rm -rf ./data/horse2zebra*
wget -nc -P ./data/ https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip --no-check-certificate
cd ./data/ && unzip horse2zebra.zip && cd ../
elif [ ${model_name} == "singan" ]; then
rm -rf ./data/SinGAN*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
cd ./data/ && unzip -q singan-official_images.zip && cd ../
mkdir -p ./data/singan
mv ./data/SinGAN-official_images/Images/stone.png ./data/singan
fi
elif [ ${MODE} = "lite_train_whole_infer" ];then
if [ ${model_name} == "pix2pix" ]; then
......@@ -102,6 +114,12 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
rm -rf ./data/reds*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/reds_lite.tar --no-check-certificate
cd ./data/ && tar xf reds_lite.tar && cd ../
elif [ ${model_name} == "singan" ]; then
rm -rf ./data/SinGAN*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
cd ./data/ && unzip -q singan-official_images.zip && cd ../
mkdir -p ./data/singan
mv ./data/SinGAN-official_images/Images/stone.png ./data/singan
fi
elif [ ${MODE} = "whole_infer" ];then
if [ ${model_name} = "pix2pix" ]; then
......@@ -145,6 +163,14 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar --no-check-certificate
cd ./inference && tar xf msvsr.tar && cd ../
cd ./data/ && tar xf reds_lite.tar && cd ../
elif [ ${model_name} == "singan" ]; then
rm -rf ./data/SinGAN*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/singan-official_images.zip --no-check-certificate
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/datasets/singan.zip --no-check-certificate
cd ./data/ && unzip -q singan-official_images.zip && cd ../
cd ./inference/ && unzip -q singan.zip && cd ../
mkdir -p ./data/singan
mv ./data/SinGAN-official_images/Images/stone.png ./data/singan
fi
fi
......@@ -25,6 +25,7 @@
| FOMM |FOMM | 生成 | 支持 | 多机多卡 | | |
| BasicVSR |BasicVSR | 超分 | 支持 | 多机多卡 | | |
|PP-MSVSR|PP-MSVSR | 超分|
|SinGAN|SinGAN | 生成| 支持 |
......
......@@ -15,7 +15,7 @@ from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
"edvr", "fom", "stylegan2", "basicvsr", "msvsr"]
"edvr", "fom", "stylegan2", "basicvsr", "msvsr", "singan"]
def parse_args():
......@@ -304,6 +304,15 @@ def main():
metric_file = os.path.join(args.output_path, model_type, "metric.txt")
for metric in metrics.values():
metric.update(out_img, gt_img, is_seq=True)
elif model_type == "singan":
predictor.run()
prediction = output_handle.copy_to_cpu()
prediction = paddle.to_tensor(prediction)
image_numpy = tensor2img(prediction, min_max)
save_image(image_numpy, os.path.join(args.output_path, "singan/{}.png".format(i)))
metric_file = os.path.join(args.output_path, "singan/metric.txt")
for metric in metrics.values():
metric.update(prediction, data['A'])
if metrics:
log_file = open(metric_file, 'a')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册