未验证 提交 0c91dc90 编写于 作者: L lvmengsi 提交者: GitHub

Cherry pick gan0722 (#2891)

* fix_gan_gpu (#2864)

* fix dir and gpu

* infer_cgan_dcgan (#2806)

add run shell and infer shell for cgan and dcgan
update infer.py for gan and dcgan
上级 6953a22e
......@@ -159,9 +159,9 @@ if __name__ == '__main__':
if args.dataset == 'mnist':
print('Download dataset: {}'.format(args.dataset))
download_mnist('./data/')
download_mnist('data')
elif args.dataset in cycle_pix_dataset:
print('Download dataset: {}'.format(args.dataset))
download_cycle_pix('./data/', args.dataset)
download_cycle_pix(os.path.join('data', args.dataset))
else:
print('Please download by yourself, thanks')
......@@ -28,12 +28,17 @@ import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
from util import utility
import copy
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_net', str, 'cgan', "The model used")
add_arg('model_net', str, 'CGAN', "The model used")
add_arg('net_G', str, "resnet_9block", "Choose the CycleGAN and Pix2pix generator's network, choose in [resnet_9block|resnet_6block|unet_128|unet_256]")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('output', str, "./infer_result", "The directory the infer result to be saved to.")
......@@ -54,6 +59,7 @@ add_arg('test_list', str, "./data/celeba/test_list_attr_celeba.txt",
add_arg('dataset_dir', str, "./data/celeba/", "the dataset directory to be infered")
add_arg('n_layers', int, 5, "default layers in generotor")
add_arg('gru_n_layers', int, 4, "default layers of GRU in generotor")
add_arg('noise_size', int, 100, "the noise dimension")
# yapf: enable
......@@ -103,6 +109,22 @@ def infer(args):
cfg=args,
name='generator',
is_test=True)
elif args.model_net == 'CGAN':
noise = fluid.layers.data(
name='noise', shape=[args.noise_size], dtype='float32')
conditions = fluid.layers.data(
name='conditions', shape=[1], dtype='float32')
from network.CGAN_network import CGAN_model
model = CGAN_model()
fake = model.network_G(noise, conditions, name="G")
elif args.model_net == 'DCGAN':
noise = fluid.layers.data(
name='noise', shape=[args.noise_size], dtype='float32')
from network.DCGAN_network import DCGAN_model
model = DCGAN_model()
fake = model.network_G(noise, name="G")
else:
raise NotImplementedError("model_net {} is not support".format(
args.model_net))
......@@ -116,7 +138,7 @@ def infer(args):
for var in fluid.default_main_program().global_block().all_parameters():
print(var.name)
print(args.init_model + '/' + model_name)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name)
fluid.io.load_persistables(exe, os.path.join(args.init_model, model_name))
print('load params done')
if not os.path.exists(args.output):
os.makedirs(args.output)
......@@ -230,6 +252,42 @@ def infer(args):
imageio.imwrite(args.output + "/fake_" + image_name, (
(fake_temp + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'CGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
label = np.random.randint(
0, 9, size=[args.batch_size, 1]).astype('float32')
noise_tensor = fluid.LoDTensor()
conditions_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
conditions_tensor.set(label, place)
fake_temp = exe.run(
fetch_list=[fake.name],
feed={"noise": noise_tensor,
"conditions": conditions_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image)
plt.savefig(
os.path.join(args.output, 'fake_cgan.png'), bbox_inches='tight')
plt.close(fig)
elif args.model_net == 'DCGAN':
noise_data = np.random.uniform(
low=-1.0, high=1.0,
size=[args.batch_size, args.noise_size]).astype('float32')
noise_tensor = fluid.LoDTensor()
noise_tensor.set(noise_data, place)
fake_temp = exe.run(fetch_list=[fake.name],
feed={"noise": noise_tensor})[0]
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image)
plt.savefig(
os.path.join(args.output, 'fake_dcgan.png'), bbox_inches='tight')
plt.close(fig)
else:
raise NotImplementedError("model_net {} is not support".format(
args.model_net))
......
python infer.py --model_net CGAN --init_model ./output/checkpoints/9/ --batch_size 32 --noise_size 100
python infer.py --model_net DCGAN --init_model ./output/checkpoints/9/ --batch_size 32 --noise_size 100
python train.py --model_net CGAN --dataset mnist --noise_size 100 --batch_size 32 --epoch 10 >log_out 2>log_err
python train.py --model_net DCGAN --dataset mnist --noise_size 100 --batch_size 32 --epoch 10 >log_out 2>log_err
......@@ -111,7 +111,7 @@ class CGAN(object):
utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")
### memory optim
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = False
......@@ -181,7 +181,7 @@ class CGAN(object):
t_time += batch_time
if batch_id % self.cfg.print_freq == 0:
image_path = self.cfg.output + '/images'
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
generate_const_image = exe.run(
......@@ -201,10 +201,9 @@ class CGAN(object):
batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
plt.savefig(
'{}/{:04d}_{:04d}.png'.format(image_path, epoch_id,
batch_id),
bbox_inches='tight')
os.path.join(image_path, img_name), bbox_inches='tight')
plt.close(fig)
if self.cfg.save_checkpoints:
......
......@@ -107,7 +107,7 @@ class DCGAN(object):
utility.init_checkpoints(self.cfg, exe, g_trainer, "net_G")
utility.init_checkpoints(self.cfg, exe, d_trainer, "net_D")
### memory optim
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = False
......@@ -166,7 +166,7 @@ class DCGAN(object):
t_time += batch_time
if batch_id % self.cfg.print_freq == 0:
image_path = self.cfg.output + '/images'
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
generate_const_image = exe.run(
......@@ -185,10 +185,9 @@ class DCGAN(object):
batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
plt.savefig(
'{}/{:04d}_{:04d}.png'.format(image_path, epoch_id,
batch_id),
bbox_inches='tight')
os.path.join(image_path, img_name), bbox_inches='tight')
plt.close(fig)
if self.cfg.save_checkpoints:
......
......@@ -49,7 +49,7 @@ def plot(gen_data):
def checkpoints(epoch, cfg, exe, trainer, name):
output_path = cfg.output + '/checkpoints/' + str(epoch)
output_path = os.path.join(cfg.output, 'checkpoints', str(epoch))
if not os.path.exists(output_path):
os.makedirs(output_path)
fluid.io.save_persistables(
......@@ -75,7 +75,7 @@ def save_test_image(epoch,
g_trainer,
A_test_reader,
B_test_reader=None):
out_path = cfg.output + '/test'
out_path = os.path.join(cfg.output, 'test')
if not os.path.exists(out_path):
os.makedirs(out_path)
if cfg.model_net == "Pix2pix":
......@@ -242,6 +242,8 @@ class ImagePool(object):
def check_attribute_conflict(label_batch, attr, attrs):
''' Based on https://github.com/LynnHo/AttGAN-Tensorflow'''
def _set(label, value, attr):
if attr in attrs:
label[attrs.index(attr)] = value
......@@ -262,10 +264,6 @@ def check_attribute_conflict(label_batch, attr, attrs):
for a in ['Straight_Hair', 'Wavy_Hair']:
if a != attr:
_set(label, 0, a)
elif attr in ['Mustache', 'No_Beard'] and attrs[attr_id] != 0:
for a in ['Mustache', 'No_Beard']:
if a != attr:
_set(label, 0, a)
return label_batch
......@@ -290,7 +288,7 @@ def check_gpu(use_gpu):
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
print(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册