Unverified commit d63b2bd3, authored by lvmengsi, committed by GitHub

fix infer and readme (#2518)

* fix infer and readme
Parent 58ef369c
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
## Contents ## Contents
-[Introduction](#简介) -[Introduction](#简介)
-[Quick start](#快速开始) -[Quick start](#快速开始)
-[References](#参考文献) -[References](#参考文献)
## Introduction ## Introduction
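...@@ -73,10 +75,61 @@ StarGAN, AttGAN and STGAN need the [Celeba](http://mmlab.ie.cuhk.edu.hk/projects ...@@ -73,10 +75,61 @@ StarGAN, AttGAN and STGAN need the [Celeba](http://mmlab.ie.cuhk.edu.hk/projects
Note: for the pix2pix model, the list file used in dataset preparation must be generated with make_pair_data.py in the scripts folder, using the following command: Note: for the pix2pix model, the list file used in dataset preparation must be generated with make_pair_data.py in the scripts folder, using the following command: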
python scripts/make_pair_data.py \ python scripts/make_pair_data.py \
--direction=A2B --direction=A2B
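Users can generate the list file by specifying the direction parameter, which fixes the direction of the image style transfer. Users can generate the list file by setting the `--direction` parameter, which fixes the direction of the image style transfer.
### Model training ### Model training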
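**Download pretrained models:**
This example provides the following pretrained models:
| Model| Pretrained model |
|:--- |:---|
| Pix2Pix | [Pretrained Pix2Pix model]() |
| CycleGAN | [Pretrained CycleGAN model]() |
| StarGAN | [Pretrained StarGAN model]() |
| AttGAN | [Pretrained AttGAN model]() |
| STGAN | [Pretrained STGAN model]() |
After downloading a pretrained model, load it by setting `--init_model` for infer.py and run inference on the images you want to test.
Run the following command to get CycleGAN predictions: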
python infer.py \
--model_net=CycleGAN \
--init_model=$(path_to_init_model) \
--image_size=256 \
--dataset_dir=$(path_to_data) \
--input_style=$(A_or_B) \
--net_G=$(generator_network) \
--g_base_dims=$(base_dim_of_generator)
The result is shown in the figure:
Run the following command to get Pix2Pix predictions:
python infer.py \
--model_net=Pix2pix \
--init_model=$(path_to_init_model) \
--image_size=256 \
--dataset_dir=$(path_to_data) \
--net_G=$(generator_network)
The result is shown in the figure:
Run the following command to get StarGAN, AttGAN and STGAN predictions:
python infer.py \
--model_net=$(StarGAN_or_AttGAN_or_STGAN) \
--init_model=$(path_to_init_model) \
--dataset_dir=$(path_to_data)
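The result is shown in the figure:
**Start training:** Once the data is prepared, start training as follows: **Start training:** Once the data is prepared, start training as follows: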
python train.py \ python train.py \
--model_net=$(name_of_model) \ --model_net=$(name_of_model) \
--dataset=$(name_of_dataset) \ --dataset=$(name_of_dataset) \
...@@ -85,15 +138,23 @@ StarGAN, AttGAN and STGAN need the [Celeba](http://mmlab.ie.cuhk.edu.hk/projects ...@@ -85,15 +138,23 @@ StarGAN, AttGAN and STGAN need the [Celeba](http://mmlab.ie.cuhk.edu.hk/projects
--test_list=$(path_to_test_data_list) \ --test_list=$(path_to_test_data_list) \
--batch_size=$(batch_size) --batch_size=$(batch_size)
Users select the model to train by setting the model_net parameter and the training dataset by setting the dataset parameter. - For the optional arguments, see:
python train.py --help
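- Each GAN comes with a run example in the scripts folder; users can run the training script directly to start training quickly.
- Users select the model to train by setting the model_net parameter and the training dataset by setting the dataset parameter.
### Model testing ### Model testing
Model testing uses a trained generator to generate images. infer.py is the main entry point; an example invocation follows: Model testing uses a trained generator to generate images. infer.py is the main entry point; an example invocation follows: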
python infer.py \ python infer.py \
--model_net=$(name_of_model) \ --model_net=$(name_of_model) \
--init_model=$(path_to_model) \ --init_model=$(path_to_model) \
--dataset_dir=$(path_to_data) --dataset_dir=$(path_to_data)
- Each GAN comes with a test example in the scripts folder; users can run the test script directly to get test results.
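The infer.py invocations in this README all share `--model_net`, `--init_model` and `--dataset_dir`, and differ only in a few extra flags. A minimal sketch of assembling such a command from Python is given here; `build_infer_command` is a hypothetical helper that is not part of the repository, and the flag values are taken from the script example in this commit.

```python
import shlex


def build_infer_command(model_net, init_model, dataset_dir, **extra_flags):
    """Assemble an argument list for infer.py (hypothetical helper, not in the repo)."""
    command = [
        "python", "infer.py",
        "--model_net={}".format(model_net),
        "--init_model={}".format(init_model),
        "--dataset_dir={}".format(dataset_dir),
    ]
    # Optional flags shown in the README, e.g. image_size, input_style, net_G.
    for name, value in sorted(extra_flags.items()):
        command.append("--{}={}".format(name, value))
    return command


cyclegan_cmd = build_infer_command(
    "CycleGAN",
    "output/checkpoints/199/",
    "data/cityscapes/testA/*",
    image_size=256,
    input_style="A",
    net_G="resnet_6block",
    g_base_dims=32,
)
# Quote each argument so the shell passes the glob pattern to infer.py unexpanded.
print(" ".join(shlex.quote(arg) for arg in cyclegan_cmd))
```

Quoting matters because the new code passes `--dataset_dir` to `glob.glob`, so the pattern should reach the script as a single argument.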
## References ## References
[1] [Goodfellow, Ian J.; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Networks. 2014. arXiv:1406.2661 [stat.ML].](https://arxiv.org/abs/1406.2661) [1] [Goodfellow, Ian J.; Pouget-Abadie, Jean; Mirza, Mehdi; Xu, Bing; Warde-Farley, David; Ozair, Sherjil; Courville, Aaron; Bengio, Yoshua. Generative Adversarial Networks. 2014. arXiv:1406.2661 [stat.ML].](https://arxiv.org/abs/1406.2661)
......
...@@ -26,6 +26,7 @@ import numpy as np ...@@ -26,6 +26,7 @@ import numpy as np
import imageio import imageio
import glob import glob
from util.config import add_arguments, print_arguments from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
import copy import copy
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -33,14 +34,12 @@ add_arg = functools.partial(add_arguments, argparser=parser) ...@@ -33,14 +34,12 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('model_net', str, 'cgan', "The model used") add_arg('model_net', str, 'cgan', "The model used")
add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]") add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
add_arg('input', str, None, "The images to be infered.")
add_arg('init_model', str, None, "The init model file of directory.") add_arg('init_model', str, None, "The init model file of directory.")
add_arg('output', str, "./infer_result", "The directory the infer result to be saved to.") add_arg('output', str, "./infer_result", "The directory the infer result to be saved to.")
add_arg('input_style', str, "A", "The style of the input, A or B") add_arg('input_style', str, "A", "The style of the input, A or B")
add_arg('norm_type', str, "batch_norm", "Which normalization to used") add_arg('norm_type', str, "batch_norm", "Which normalization to used")
add_arg('use_gpu', bool, True, "Whether to use GPU to train.") add_arg('use_gpu', bool, True, "Whether to use GPU to train.")
add_arg('dropout', bool, False, "Whether to use dropout") add_arg('dropout', bool, False, "Whether to use dropout")
add_arg('data_shape', int, 256, "The shape of load image")
add_arg('g_base_dims', int, 64, "Base channels in CycleGAN generator") add_arg('g_base_dims', int, 64, "Base channels in CycleGAN generator")
add_arg('c_dim', int, 13, "the size of attrs") add_arg('c_dim', int, 13, "the size of attrs")
add_arg('use_gru', bool, False, "Whether to use GRU") add_arg('use_gru', bool, False, "Whether to use GRU")
...@@ -51,14 +50,14 @@ add_arg('selected_attrs', str, ...@@ -51,14 +50,14 @@ add_arg('selected_attrs', str,
"the attributes we selected to change") "the attributes we selected to change")
add_arg('batch_size', int, 16, "batch size when test") add_arg('batch_size', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file") add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file")
add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory") add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
add_arg('n_layers', int, 5, "default layers in generotor") add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor") add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
# yapf: enable # yapf: enable
def infer(args): def infer(args):
data_shape = [-1, 3, args.data_shape, args.data_shape] data_shape = [-1, 3, args.image_size, args.image_size]
input = fluid.layers.data(name='input', shape=data_shape, dtype='float32') input = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
label_org_ = fluid.layers.data( label_org_ = fluid.layers.data(
name='label_org_', shape=[args.c_dim], dtype='float32') name='label_org_', shape=[args.c_dim], dtype='float32')
...@@ -66,7 +65,7 @@ def infer(args): ...@@ -66,7 +65,7 @@ def infer(args):
name='label_trg_', shape=[args.c_dim], dtype='float32') name='label_trg_', shape=[args.c_dim], dtype='float32')
model_name = 'net_G' model_name = 'net_G'
if args.model_net == 'cyclegan': if args.model_net == 'CycleGAN':
from network.CycleGAN_network import CycleGAN_model from network.CycleGAN_network import CycleGAN_model
model = CycleGAN_model() model = CycleGAN_model()
if args.input_style == "A": if args.input_style == "A":
...@@ -136,10 +135,11 @@ def infer(args): ...@@ -136,10 +135,11 @@ def infer(args):
images = [real_img_temp] images = [real_img_temp]
for i in range(args.c_dim): for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_trg) label_trg_tmp = copy.deepcopy(label_trg)
for j in range(args.batch_size): for j in range(len(label_org)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i] label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_ = map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp) label_trg_ = list(
for j in range(args.batch_size): map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
for j in range(len(label_org)):
label_trg_[j][i] = label_trg_[j][i] * 2.0 label_trg_[j][i] = label_trg_[j][i] * 2.0
tensor_label_org_.set(label_org, place) tensor_label_org_.set(label_org, place)
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
...@@ -149,7 +149,7 @@ def infer(args): ...@@ -149,7 +149,7 @@ def infer(args):
"label_org_": tensor_label_org_, "label_org_": tensor_label_org_,
"label_trg_": tensor_label_trg_ "label_trg_": tensor_label_trg_
}, },
fetch_list=fake.name) fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1]) fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
images.append(fake_temp) images.append(fake_temp)
images_concat = np.concatenate(images, 1) images_concat = np.concatenate(images, 1)
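The hunk above replaces `range(args.batch_size)` with `range(len(label_org))` and wraps `map(...)` in `list(...)`. The first change keeps the loop within bounds when a batch holds fewer than `batch_size` samples; the second is needed on Python 3, where `map` returns an iterator that cannot be indexed. A vectorized NumPy sketch of the same label arithmetic is given here; `flip_and_scale_attribute` is a hypothetical name, and the sketch assumes the labels are a 2-D array of 0/1 values.

```python
import numpy as np


def flip_and_scale_attribute(label_trg, attr_index):
    """Return a copy of label_trg with one attribute flipped and rescaled.

    Mirrors the loop in the diff: flip column attr_index, map {0, 1} to
    {-0.5, 0.5}, then double the flipped column so it lands on {-1, 1}.
    """
    labels = np.array(label_trg, dtype="float32")  # np.array copies its input
    labels[:, attr_index] = 1.0 - labels[:, attr_index]
    scaled = (labels * 2.0 - 1.0) * 0.5
    scaled[:, attr_index] *= 2.0
    return scaled


# Two samples with three attributes; the row count is read from the data,
# not from a fixed batch size.
batch = np.array([[1.0, 0.0, 1.0],
                  [0.0, 0.0, 1.0]], dtype="float32")
flipped = flip_and_scale_attribute(batch, attr_index=0)
print(flipped)
```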
...@@ -167,29 +167,33 @@ def infer(args): ...@@ -167,29 +167,33 @@ def infer(args):
args, shuffle=False, return_name=True) args, shuffle=False, return_name=True)
for data in zip(reader_test()): for data in zip(reader_test()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
print("read {}".format(name))
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor() tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place) tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place) tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([1, 2, 0]) real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
images = [real_img_temp] images = [real_img_temp]
for i in range(cfg.c_dim): for i in range(args.c_dim):
label_trg = np.zeros([1, cfg.c_dim]).astype("float32") label_trg = np.zeros(
label_trg[0][i] = 1 [len(label_org), args.c_dim]).astype("float32")
for j in range(len(label_org)):
label_trg[j][i] = 1
tensor_label_trg = fluid.LoDTensor() tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
out = exe.run( out = exe.run(
feed={"input": tensor_img, feed={"input": tensor_img,
"label_trg_": tensor_label_trg}, "label_trg_": tensor_label_trg},
fetch_list=fake.name) fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([1, 2, 0]) fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
images.append(fake_temp) images.append(fake_temp)
images_concat = np.concatenate(images, 1) images_concat = np.concatenate(images, 1)
imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0], images_concat = np.concatenate(images_concat, 1)
((images_concat + 1) * 127.5).astype(np.uint8)) imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'Pix2pix' or args.model_net == 'cyclegan': elif args.model_net == 'Pix2pix' or args.model_net == 'CycleGAN':
for file in glob.glob(args.input): for file in glob.glob(args.dataset_dir):
print("read {}".format(file)) print("read {}".format(file))
image_name = os.path.basename(file) image_name = os.path.basename(file)
image = Image.open(file).convert('RGB') image = Image.open(file).convert('RGB')
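The last hunk changes the transpose from `[1, 2, 0]` to `[0, 2, 3, 1]` and adds a second `np.concatenate`, which suggests the results are now tiled for a whole batch rather than a single image. The sketch below illustrates that tiling and the `(x + 1) * 127.5` rescaling on dummy data. `tile_results` is a hypothetical name, and the sketch assumes 4-D NCHW inputs in [-1, 1]; it omits the `np.squeeze` call used in infer.py.

```python
import numpy as np


def tile_results(images):
    """Tile a list of NCHW batches in [-1, 1] into one HWC uint8 image."""
    # NCHW -> NHWC for every batch (real images first, then each generated set).
    nhwc = [np.transpose(img, [0, 2, 3, 1]) for img in images]
    # Stack the variants of each sample vertically: [N, k*H, W, C].
    per_sample = np.concatenate(nhwc, 1)
    # Iterating over axis 0 yields N arrays of [k*H, W, C]; join them side by side.
    grid = np.concatenate(per_sample, 1)
    # Map [-1, 1] to [0, 255] before saving.
    return ((grid + 1) * 127.5).astype(np.uint8)


real = np.zeros([2, 3, 4, 4], dtype="float32")
fake = np.ones([2, 3, 4, 4], dtype="float32")
grid = tile_results([real, fake])
print(grid.shape)
```

Each column of the grid is one input sample, with the original image on top and the generated variants below it.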
......
python infer.py --init_model output/checkpoints/199/ --input data/cityscapes/testA/* --input_style A --model_net cyclegan --net_G resnet_6block --g_base_dims 32 python infer.py --init_model output/checkpoints/199/ --dataset_dir "data/cityscapes/testA/*" --image_size 256 --input_style A --model_net CycleGAN --net_G resnet_6block --g_base_dims 32
python infer.py --init_model output/checkpoints/199/ --input "data/cityscapes/testB/*" --model_net Pix2pix --net_G unet_256 python infer.py --init_model output/checkpoints/199/ --image_size 256 --dataset_dir "data/cityscapes/testB/*" --model_net Pix2pix --net_G unet_256