未验证 提交 0bd09f64 编写于 作者: L lvmengsi 提交者: GitHub

fix_gan_gpu (#2864)

* fix dir and gpu
上级 87906f04
......@@ -159,9 +159,9 @@ if __name__ == '__main__':
if args.dataset == 'mnist':
print('Download dataset: {}'.format(args.dataset))
download_mnist('./data/')
download_mnist('data')
elif args.dataset in cycle_pix_dataset:
print('Download dataset: {}'.format(args.dataset))
download_cycle_pix('./data/', args.dataset)
download_cycle_pix(os.path.join('data', args.dataset))
else:
print('Please download by yourself, thanks')
......@@ -138,7 +138,7 @@ def infer(args):
for var in fluid.default_main_program().global_block().all_parameters():
print(var.name)
print(args.init_model + '/' + model_name)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name)
fluid.io.load_persistables(exe, os.path.join(args.init_model, model_name))
print('load params done')
if not os.path.exists(args.output):
os.makedirs(args.output)
......@@ -270,7 +270,8 @@ def infer(args):
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image)
plt.savefig(args.output + '/fake_cgan.png', bbox_inches='tight')
plt.savefig(
os.path.join(args.output, 'fake_cgan.png'), bbox_inches='tight')
plt.close(fig)
elif args.model_net == 'DCGAN':
......@@ -284,7 +285,8 @@ def infer(args):
fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image)
plt.savefig(args.output + '/fake_dcgan.png', bbox_inches='tight')
plt.savefig(
os.path.join(args.output, '/fake_dcgan.png'), bbox_inches='tight')
plt.close(fig)
else:
raise NotImplementedError("model_net {} is not support".format(
......
......@@ -181,7 +181,7 @@ class CGAN(object):
t_time += batch_time
if batch_id % self.cfg.print_freq == 0:
image_path = self.cfg.output + '/images'
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
generate_const_image = exe.run(
......@@ -201,10 +201,9 @@ class CGAN(object):
batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
plt.savefig(
'{}/{:04d}_{:04d}.png'.format(image_path, epoch_id,
batch_id),
bbox_inches='tight')
os.path.join(image_path, img_name), bbox_inches='tight')
plt.close(fig)
if self.cfg.save_checkpoints:
......
......@@ -166,7 +166,7 @@ class DCGAN(object):
t_time += batch_time
if batch_id % self.cfg.print_freq == 0:
image_path = self.cfg.output + '/images'
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
generate_const_image = exe.run(
......@@ -185,10 +185,9 @@ class DCGAN(object):
batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
plt.savefig(
'{}/{:04d}_{:04d}.png'.format(image_path, epoch_id,
batch_id),
bbox_inches='tight')
os.path.join(image_path, img_name), bbox_inches='tight')
plt.close(fig)
if self.cfg.save_checkpoints:
......
......@@ -49,7 +49,7 @@ def plot(gen_data):
def checkpoints(epoch, cfg, exe, trainer, name):
output_path = cfg.output + '/checkpoints/' + str(epoch)
output_path = os.path.join(cfg.output, 'checkpoints', str(epoch))
if not os.path.exists(output_path):
os.makedirs(output_path)
fluid.io.save_persistables(
......@@ -75,7 +75,7 @@ def save_test_image(epoch,
g_trainer,
A_test_reader,
B_test_reader=None):
out_path = cfg.output + '/test'
out_path = os.path.join(cfg.output, 'test')
if not os.path.exists(out_path):
os.makedirs(out_path)
if cfg.model_net == "Pix2pix":
......@@ -242,6 +242,8 @@ class ImagePool(object):
def check_attribute_conflict(label_batch, attr, attrs):
''' Based on https://github.com/LynnHo/AttGAN-Tensorflow'''
def _set(label, value, attr):
if attr in attrs:
label[attrs.index(attr)] = value
......@@ -262,10 +264,6 @@ def check_attribute_conflict(label_batch, attr, attrs):
for a in ['Straight_Hair', 'Wavy_Hair']:
if a != attr:
_set(label, 0, a)
elif attr in ['Mustache', 'No_Beard'] and attrs[attr_id] != 0:
for a in ['Mustache', 'No_Beard']:
if a != attr:
_set(label, 0, a)
return label_batch
......@@ -290,7 +288,7 @@ def check_gpu(use_gpu):
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
print(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册