未验证 提交 0bd09f64 编写于 作者: L lvmengsi 提交者: GitHub

fix_gan_gpu (#2864)

* fix dir and gpu
上级 87906f04
...@@ -159,9 +159,9 @@ if __name__ == '__main__': ...@@ -159,9 +159,9 @@ if __name__ == '__main__':
if args.dataset == 'mnist': if args.dataset == 'mnist':
print('Download dataset: {}'.format(args.dataset)) print('Download dataset: {}'.format(args.dataset))
download_mnist('./data/') download_mnist('data')
elif args.dataset in cycle_pix_dataset: elif args.dataset in cycle_pix_dataset:
print('Download dataset: {}'.format(args.dataset)) print('Download dataset: {}'.format(args.dataset))
download_cycle_pix('./data/', args.dataset) download_cycle_pix(os.path.join('data', args.dataset))
else: else:
print('Please download by yourself, thanks') print('Please download by yourself, thanks')
...@@ -138,7 +138,7 @@ def infer(args): ...@@ -138,7 +138,7 @@ def infer(args):
for var in fluid.default_main_program().global_block().all_parameters(): for var in fluid.default_main_program().global_block().all_parameters():
print(var.name) print(var.name)
print(args.init_model + '/' + model_name) print(args.init_model + '/' + model_name)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name) fluid.io.load_persistables(exe, os.path.join(args.init_model, model_name))
print('load params done') print('load params done')
if not os.path.exists(args.output): if not os.path.exists(args.output):
os.makedirs(args.output) os.makedirs(args.output)
...@@ -270,7 +270,8 @@ def infer(args): ...@@ -270,7 +270,8 @@ def infer(args):
fake_image = np.reshape(fake_temp, (args.batch_size, -1)) fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image) fig = utility.plot(fake_image)
plt.savefig(args.output + '/fake_cgan.png', bbox_inches='tight') plt.savefig(
os.path.join(args.output, 'fake_cgan.png'), bbox_inches='tight')
plt.close(fig) plt.close(fig)
elif args.model_net == 'DCGAN': elif args.model_net == 'DCGAN':
...@@ -284,7 +285,8 @@ def infer(args): ...@@ -284,7 +285,8 @@ def infer(args):
fake_image = np.reshape(fake_temp, (args.batch_size, -1)) fake_image = np.reshape(fake_temp, (args.batch_size, -1))
fig = utility.plot(fake_image) fig = utility.plot(fake_image)
plt.savefig(args.output + '/fake_dcgan.png', bbox_inches='tight') plt.savefig(
os.path.join(args.output, '/fake_dcgan.png'), bbox_inches='tight')
plt.close(fig) plt.close(fig)
else: else:
raise NotImplementedError("model_net {} is not support".format( raise NotImplementedError("model_net {} is not support".format(
......
...@@ -181,7 +181,7 @@ class CGAN(object): ...@@ -181,7 +181,7 @@ class CGAN(object):
t_time += batch_time t_time += batch_time
if batch_id % self.cfg.print_freq == 0: if batch_id % self.cfg.print_freq == 0:
image_path = self.cfg.output + '/images' image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path): if not os.path.exists(image_path):
os.makedirs(image_path) os.makedirs(image_path)
generate_const_image = exe.run( generate_const_image = exe.run(
...@@ -201,10 +201,9 @@ class CGAN(object): ...@@ -201,10 +201,9 @@ class CGAN(object):
batch_time)) batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id, plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id)) batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
plt.savefig( plt.savefig(
'{}/{:04d}_{:04d}.png'.format(image_path, epoch_id, os.path.join(image_path, img_name), bbox_inches='tight')
batch_id),
bbox_inches='tight')
plt.close(fig) plt.close(fig)
if self.cfg.save_checkpoints: if self.cfg.save_checkpoints:
......
...@@ -166,7 +166,7 @@ class DCGAN(object): ...@@ -166,7 +166,7 @@ class DCGAN(object):
t_time += batch_time t_time += batch_time
if batch_id % self.cfg.print_freq == 0: if batch_id % self.cfg.print_freq == 0:
image_path = self.cfg.output + '/images' image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path): if not os.path.exists(image_path):
os.makedirs(image_path) os.makedirs(image_path)
generate_const_image = exe.run( generate_const_image = exe.run(
...@@ -185,10 +185,9 @@ class DCGAN(object): ...@@ -185,10 +185,9 @@ class DCGAN(object):
batch_time)) batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id, plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id)) batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
plt.savefig( plt.savefig(
'{}/{:04d}_{:04d}.png'.format(image_path, epoch_id, os.path.join(image_path, img_name), bbox_inches='tight')
batch_id),
bbox_inches='tight')
plt.close(fig) plt.close(fig)
if self.cfg.save_checkpoints: if self.cfg.save_checkpoints:
......
...@@ -49,7 +49,7 @@ def plot(gen_data): ...@@ -49,7 +49,7 @@ def plot(gen_data):
def checkpoints(epoch, cfg, exe, trainer, name): def checkpoints(epoch, cfg, exe, trainer, name):
output_path = cfg.output + '/checkpoints/' + str(epoch) output_path = os.path.join(cfg.output, 'checkpoints', str(epoch))
if not os.path.exists(output_path): if not os.path.exists(output_path):
os.makedirs(output_path) os.makedirs(output_path)
fluid.io.save_persistables( fluid.io.save_persistables(
...@@ -75,7 +75,7 @@ def save_test_image(epoch, ...@@ -75,7 +75,7 @@ def save_test_image(epoch,
g_trainer, g_trainer,
A_test_reader, A_test_reader,
B_test_reader=None): B_test_reader=None):
out_path = cfg.output + '/test' out_path = os.path.join(cfg.output, 'test')
if not os.path.exists(out_path): if not os.path.exists(out_path):
os.makedirs(out_path) os.makedirs(out_path)
if cfg.model_net == "Pix2pix": if cfg.model_net == "Pix2pix":
...@@ -242,6 +242,8 @@ class ImagePool(object): ...@@ -242,6 +242,8 @@ class ImagePool(object):
def check_attribute_conflict(label_batch, attr, attrs): def check_attribute_conflict(label_batch, attr, attrs):
''' Based on https://github.com/LynnHo/AttGAN-Tensorflow'''
def _set(label, value, attr): def _set(label, value, attr):
if attr in attrs: if attr in attrs:
label[attrs.index(attr)] = value label[attrs.index(attr)] = value
...@@ -262,10 +264,6 @@ def check_attribute_conflict(label_batch, attr, attrs): ...@@ -262,10 +264,6 @@ def check_attribute_conflict(label_batch, attr, attrs):
for a in ['Straight_Hair', 'Wavy_Hair']: for a in ['Straight_Hair', 'Wavy_Hair']:
if a != attr: if a != attr:
_set(label, 0, a) _set(label, 0, a)
elif attr in ['Mustache', 'No_Beard'] and attrs[attr_id] != 0:
for a in ['Mustache', 'No_Beard']:
if a != attr:
_set(label, 0, a)
return label_batch return label_batch
...@@ -290,7 +288,7 @@ def check_gpu(use_gpu): ...@@ -290,7 +288,7 @@ def check_gpu(use_gpu):
try: try:
if use_gpu and not fluid.is_compiled_with_cuda(): if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err) print(err)
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
pass pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册