未验证 提交 b905c5d7 编写于 作者: L lvmengsi 提交者: GitHub

fix spade infer for 1.6(#3575) (#3668)

* fix spade infer (#3575)

* fix spade

* update SPADE

* update spade

* update infer

* update infer

* update infer

* fix_spade_init
上级 6d611bb8
......@@ -16,7 +16,7 @@
## 模型简介
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\], SPADE\[[13](#参考文献)\]
注意:
1. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
......@@ -86,6 +86,7 @@ StarGAN,AttGAN和STGAN采用celeba\[[11](#参考文献)\]数据集进行属性
通过指定dataset参数来下载相应的数据集。
StarGAN, AttGAN和STGAN所需要的[Celeba](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)数据集可以自行下载。
SPADE使用的[cityscapes](https://www.cityscapes-dataset.com)数据集可以自行下载。下载完成后新建一个目录data/cityscapes/,并在目录下准备3个子目录,分别是真实图片、分割图、实例图。准备一个train_list和test_list,每一行的顺序是分割图\t真实图\t实例图。
**自定义数据集:**
如果您要使用自定义的数据集,只要设置成对应的生成模型所需要的数据格式,并放在data文件夹下,然后把`--dataset`参数设置成您自定义数据集的名称,data_reader.py文件就会自动去data文件夹中寻找数据。
......@@ -113,6 +114,7 @@ StarGAN, AttGAN和STGAN所需要的[Celeba](http://mmlab.ie.cuhk.edu.hk/projects
- 每个GAN都给出了一份运行示例,放在scripts文件夹内,用户可以直接运行训练脚本快速开始训练。
- 用户可以通过设置`--model_net`参数来选择想要训练的模型,通过设置`--dataset`参数来选择训练所需要的数据集。
- SPADE模型的训练需要在主目录下新建一个VGG19_pretrained目录,从[该链接](https://paddle-imagenet-models-name.bj.bcebos.com/VGG19_pretrained.tar)下载在ImageNet上预训练好的VGG19模型,解压之后把VGG19模型的参数名改成`vgg19_`开头的参数名。
### 模型测试
模型测试是利用训练完成的生成模型进行图像生成。infer.py是主要的执行程序,调用示例如下:
......@@ -183,6 +185,7 @@ STGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hai
| StarGAN | [StarGAN的预训练模型](https://paddle-gan-models.bj.bcebos.com/stargan_G.tar.gz) |
| AttGAN | [AttGAN的预训练模型](https://paddle-gan-models.bj.bcebos.com/attgan_G.tar.gz) |
| STGAN | [STGAN的预训练模型](https://paddle-gan-models.bj.bcebos.com/stgan_G.tar.gz) |
| SPADE | [SPADE的预训练模型]() ([SPADE需要的vgg预训练模型]())
## 进阶使用
......@@ -202,6 +205,8 @@ AttGAN利用分类损失和重构损失来保证改变特定的属性,可用
STGAN只输入有变化的标签,引入GRU结构,更好的选择变化的属性,可用于人脸特定属性转换。
SPADE提出一种考虑空间语义信息的归一化方法,从而更好的保留语义信息,生成更为逼真的图像,可用于图像翻译。
### 模型概览
- Pix2Pix由一个生成网络和一个判别网络组成。生成网络中编码部分的网络结构都是采用`convolution-batch norm-ReLU`作为基础结构,解码部分的网络结构由`transpose convolution-batch norm-ReLU`组成,判别网络基本是由`convolution-norm-leaky_ReLU`作为基础结构,详细的网络结构可以查看`network/Pix2pix_network.py`文件。生成网络提供两种可选的网络结构:Unet网络结构和普通的encoder-decoder网络结构。网络利用损失函数学习从输入图像到输出图像的映射,生成网络损失函数由GAN的损失函数和L1损失函数组成,判别网络损失函数由GAN的损失函数组成。生成器的网络结构如下图所示:
......@@ -245,6 +250,13 @@ AttGAN的网络结构[8]
STGAN的网络结构[9]
</p>
- SPADE中整体网络结构如下图所示。SPADE在网络中的卷积层使用了[谱归一化](\[[12](#参考文献)\]),把输入图像的语义mask图像作为生成网络输入,拼接了语义mask和生成器的输出为判别网络的输入。SPADE提出了一种基于空间信息的归一化方法\(SPatially-Adaptive \(DE\)normalization\),在进行归一化的时候可以更好的利用语义信息,从而生成更为逼真的图像。更为具体的网络结构可以参考network/SPADE_network.py文件或者论文中的附录部分。
<p align="center">
<img src="images/spade_net.png" width=800 /> <br />
SPADE整体的网络结构[10]
</p>
注意:网络结构中的norm指的是用户可以选用batch norm或者instance norm来搭建自己的网络。
......@@ -291,6 +303,10 @@ STGAN的网络结构[9]
[11] [Deep Learning Face Attributes in the Wild](https://arxiv.org/abs/1411.7766)
[12] [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957)
[13] [Semantic Image Synthesis with Spatially-Adaptive Normalization](https://arxiv.org/abs/1903.07291)
## 版本更新
......
......@@ -618,6 +618,11 @@ class data_reader(object):
train_list = os.path.join(dataset_dir, 'train.txt')
if self.cfg.train_list is not None:
train_list = self.cfg.train_list
if not os.path.exists(train_list):
print(
"train_list is NOT EXIST!!! Please prepare train list first"
)
sys.exit(1)
train_reader = triplex_reader_creator(
image_dir=dataset_dir,
list_filename=train_list,
......@@ -629,6 +634,11 @@ class data_reader(object):
test_list = os.path.join(dataset_dir, "test.txt")
if self.cfg.test_list is not None:
test_list = self.cfg.test_list
if not os.path.exists(test_list):
print(
"test_list is NOT EXIST!!! Please prepare test list first"
)
sys.exit(1)
test_reader = triplex_reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
......
......@@ -170,12 +170,14 @@ def infer(args):
model = DCGAN_model(args.n_samples)
fake = model.network_G(noise, name="G")
elif args.model_net == 'SPADE':
label_shape = [None, args.label_nc, args.crop_height, args.crop_width]
spade_data_shape = [None, 1, args.crop_height, args.crop_width]
from network.SPADE_network import SPADE_model
model = SPADE_model()
input_label = fluid.layers.data(
name='input_label', shape=data_shape, dtype='float32')
input_ins = fluid.layers.data(
name='input_ins', shape=data_shape, dtype='float32')
input_label = fluid.data(
name='input_label', shape=label_shape, dtype='float32')
input_ins = fluid.data(
name='input_ins', shape=spade_data_shape, dtype='float32')
input_ = fluid.layers.concat([input_label, input_ins], 1)
fake = model.network_G(input_, "generator", cfg=args, is_test=True)
else:
......@@ -320,10 +322,12 @@ def infer(args):
shuffle=False,
batch_size=1,
mode="TEST")
id2name = test_reader.id2name
reader_test = test_reader.make_reader(args, return_name=True)
for data in zip(reader_test()):
data_A, data_B, data_C, name = data[0]
name = name[0]
name = id2name[np.array(name).astype('int32')[0]]
print("read: ", name)
tensor_A = fluid.LoDTensor()
tensor_C = fluid.LoDTensor()
tensor_A.set(data_A, place)
......
......@@ -83,7 +83,7 @@ class CGAN_model(object):
name=name + '_dc2',
output_size=[self.img_w, self.img_h])
out = fluid.layers.reshape(o_dc2, [-1, self.img_w * self.img_h])
return o_dc2
return out
def network_D(self, input, label, name="discriminator"):
# concat image and label
......
......@@ -40,6 +40,7 @@ class SPADE_model(object):
padding=1,
name=name + "_fc",
use_bias=True,
initial="kaiming",
is_test=is_test)
x = self.SPADEResnetBlock(
x,
......@@ -88,6 +89,7 @@ class SPADE_model(object):
padding=1,
name=name + "_conv_img",
use_bias=True,
initial="kaiming",
is_test=is_test)
x = fluid.layers.tanh(x)
......@@ -148,6 +150,7 @@ class SPADE_model(object):
padding=pw,
activation_fn='relu',
name=name + ".mlp_shared.0",
initial="kaiming",
use_bias=True)
gamma = conv2d(
actv,
......@@ -155,6 +158,7 @@ class SPADE_model(object):
ks,
padding=pw,
name=name + ".mlp_gamma",
initial="kaiming",
use_bias=True)
beta = conv2d(
actv,
......@@ -162,6 +166,7 @@ class SPADE_model(object):
ks,
padding=pw,
name=name + ".mlp_beta",
initial="kaiming",
use_bias=True)
param_attr = fluid.ParamAttr(
name=name + ".param_free_norm.weight",
......@@ -219,6 +224,7 @@ def build_discriminator_Nlayers(input,
name=name + ".model0.0",
activation_fn='leaky_relu',
relufactor=0.2,
initial="kaiming",
use_bias=True)
d_dims = d_base_dims
res_list.append(res1)
......@@ -248,6 +254,7 @@ def build_discriminator_Nlayers(input,
0.02,
1,
name + ".model{}.0".format(d_nlayers),
initial="kaiming",
use_bias=True)
res_list.append(o_c4)
return res_list
......@@ -414,7 +414,8 @@ def conv2d_spectral_norm(input,
dtype = helper.input_dtype()
weight_param = fluid.ParamAttr(
name=name + ".weight_orig",
initializer=fluid.initializer.Constant(1.0),
initializer=fluid.initializer.Normal(
loc=0.0, scale=1.0),
trainable=True)
weight = helper.create_parameter(
attr=weight_param,
......@@ -425,7 +426,9 @@ def conv2d_spectral_norm(input,
weight = weight_spectral_norm
if use_bias:
bias_attr = fluid.ParamAttr(
name=name + "_b", initializer=fluid.initializer.Constant(0.0))
name=name + "_b",
initializer=fluid.initializer.Normal(
loc=0.0, scale=1.0))
else:
bias_attr = False
conv = conv2d_with_filter(
......
python infer.py --model_net SPADE --test_list ./data/cityscapes/test_list --load_height 512 --load_width 1024 --crop_height 512 --crop_width 1024 --dataset_dir ./data/cityscapes/ --init_model ./spade_py37/checkpoints/99/
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_fraction_of_gpu_memory_to_use=0.01
CUDA_VISIBLE_DEVICES=0 python train.py --model_net SPADE --dataset cityscapes --train_list train_list --test_list val_list --crop_type Random --batch_size 1 --epoch 200 --load_height 612 --load_width 1124 --crop_height 512 --crop_width 1024 --label_nc 36
CUDA_VISIBLE_DEVICES=0 python train.py --model_net SPADE --dataset cityscapes --train_list ./data/cityscapes/train_list --test_list ./data/cityscapes/val_list --crop_type Random --batch_size 1 --epoch 200 --load_height 612 --load_width 1124 --crop_height 512 --crop_width 1024 --label_nc 36
......@@ -19,6 +19,7 @@ from network.SPADE_network import SPADE_model
from util import utility
import paddle.fluid as fluid
import sys
import os
import time
import network.vgg as vgg
import pickle as pkl
......@@ -316,6 +317,12 @@ class SPADE(object):
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if not os.path.exists(self.cfg.vgg19_pretrain):
print(
"directory VGG19_pretrain NOT EXIST!!! Please download VGG19 first."
)
sys.exit(1)
gen_trainer.vgg.load_vars(exe, gen_trainer.program,
self.cfg.vgg19_pretrain)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册