未验证 提交 28aeb7a4 编写于 作者: L lvmengsi 提交者: GitHub

infer_cgan_dcgan (#2806)

add run shell and infer shell for cgan and dcgan
update infer.py for gan and dcgan
上级 fd84c497
......@@ -28,12 +28,17 @@ import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
from util import utility
import copy
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_net', str, 'cgan', "The model used")
add_arg('model_net', str, 'CGAN', "The model used")
add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('output', str, "./infer_result", "The directory the infer result to be saved to.")
......@@ -50,10 +55,11 @@ add_arg('selected_attrs', str,
"Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young",
"the attributes we selected to change")
add_arg('batch_size', int, 16, "batch size when test")
add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file")
add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt", "the test list file")
add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
add_arg('noise_size', int, 100, "the noise dimension")
# yapf: enable
......@@ -103,6 +109,22 @@ def infer(args):
cfg=args,
name='generator',
is_test=True)
elif args.model_net == 'CGAN':
noise = fluid.layers.data(
name='noise', shape=[args.noise_size], dtype='float32')
conditions = fluid.layers.data(
name='conditions', shape=[1], dtype='float32')
from network.CGAN_network import CGAN_model
model = CGAN_model()
fake = model.network_G(noise, conditions, name="G")
elif args.model_net == 'DCGAN':
noise = fluid.layers.data(
name='noise', shape=[args.noise_size], dtype='float32')
from network.DCGAN_network import DCGAN_model
model = DCGAN_model()
fake = model.network_G(noise, name="G")
else:
raise NotImplementedError("model_net {} is not support".format(
args.model_net))
......@@ -230,6 +252,40 @@ def infer(args):
imageio.imwrite(args.output + "/fake_" + image_name, (
(fake_temp + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'CGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
label = np.random.randint(
0, 9, size=[args.batch_size, 1]).astype('float32')
noise_tensor = fluid.LoDTensor()
conditions_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
conditions_tensor.set(label, place)
fake_temp = exe.run(
fetch_list=[fake.name],
feed={"noise": noise_tensor,
"conditions": conditions_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image)
plt.savefig(args.output + '/fake_cgan.png', bbox_inches='tight')
plt.close(fig)
elif args.model_net == 'DCGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
noise_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
fake_temp = exe.run(fetch_list=[fake.name],
feed={"noise": noise_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image)
plt.savefig(args.output + '/fake_dcgan.png', bbox_inches='tight')
plt.close(fig)
else:
raise NotImplementedError("model_net {} is not support".format(
args.model_net))
......
python infer.py --model_net CGAN --init_model ./output/checkpoints/9/ --batch_size 32 --noise_size 100
python infer.py --model_net DCGAN --init_model ./output/checkpoints/9/ --batch_size 32 --noise_size 100
python train.py --model_net CGAN --dataset mnist --noise_size 100 --batch_size 32 --epoch 10 >log_out 2>log_err
python train.py --model_net DCGAN --dataset mnist --noise_size 100 --batch_size 32 --epoch 10 >log_out 2>log_err
......@@ -111,7 +111,7 @@ class CGAN(object):
utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")
### memory optim
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = False
......
......@@ -107,7 +107,7 @@ class DCGAN(object):
utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")
### memory optim
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册