未验证 提交 25932361 编写于 作者: L lvmengsi 提交者: GitHub

fix easy used (#3211)

* fix used
上级 50280922
...@@ -156,7 +156,7 @@ def conv2d(input, ...@@ -156,7 +156,7 @@ def conv2d(input,
if padding_type == "SAME": if padding_type == "SAME":
top_padding, bottom_padding = cal_padding(input.shape[2], stride, top_padding, bottom_padding = cal_padding(input.shape[2], stride,
filter_size) filter_size)
left_padding, right_padding = cal_padding(input.shape[2], stride, left_padding, right_padding = cal_padding(input.shape[3], stride,
filter_size) filter_size)
height_padding = bottom_padding height_padding = bottom_padding
width_padding = right_padding width_padding = right_padding
...@@ -244,7 +244,7 @@ def deconv2d(input, ...@@ -244,7 +244,7 @@ def deconv2d(input,
if padding_type == "SAME": if padding_type == "SAME":
top_padding, bottom_padding = cal_padding(input.shape[2], stride, top_padding, bottom_padding = cal_padding(input.shape[2], stride,
filter_size) filter_size)
left_padding, right_padding = cal_padding(input.shape[2], stride, left_padding, right_padding = cal_padding(input.shape[3], stride,
filter_size) filter_size)
height_padding = bottom_padding height_padding = bottom_padding
width_padding = right_padding width_padding = right_padding
......
python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir ./data/celeba/ --image_size 128 python infer.py --model_net AttGAN --init_model output/checkpoints/119/ --dataset_dir ./data/celeba/ --image_size 128 --output ./infer_result/attgan/
python infer.py --model_net CGAN --init_model ./output/checkpoints/19/ --n_samples 32 --noise_size 100 python infer.py --model_net CGAN --init_model ./output/checkpoints/19/ --n_samples 32 --noise_size 100 --output ./infer_result/c_gan/
python infer.py --init_model output/checkpoints/199/ --dataset_dir data/cityscapes/ --image_size 256 --n_samples 1 --crop_size 256 --input_style B --test_list ./data/cityscapes/testB.txt --model_net CycleGAN --net_G resnet_9block --g_base_dims 32 python infer.py --init_model output/checkpoints/199/ --dataset_dir data/cityscapes/ --image_size 256 --n_samples 1 --crop_size 256 --input_style B --test_list ./data/cityscapes/testB.txt --model_net CycleGAN --net_G resnet_9block --g_base_dims 32 --output ./infer_result/cyclegan/
python infer.py --model_net DCGAN --init_model ./output/checkpoints/19/ --n_samples 32 --noise_size 100 python infer.py --model_net DCGAN --init_model ./output/checkpoints/19/ --n_samples 32 --noise_size 100 --output ./infer_result/dcgan/
python infer.py --init_model output/checkpoints/199/ --image_size 256 --n_samples 1 --crop_size 256 --dataset_dir data/cityscapes/ --model_net Pix2pix --net_G unet_256 --test_list data/cityscapes/testB.txt python infer.py --init_model output/checkpoints/199/ --image_size 256 --n_samples 1 --crop_size 256 --dataset_dir data/cityscapes/ --model_net Pix2pix --net_G unet_256 --test_list data/cityscapes/testB.txt --output ./infer_result/pix2pix/
python infer.py --model_net StarGAN --init_model ./output/checkpoints/19/ --dataset_dir ./data/celeba/ --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young" python infer.py --model_net StarGAN --init_model ./output/checkpoints/19/ --dataset_dir ./data/celeba/ --image_size 128 --c_dim 5 --selected_attrs "Black_Hair,Blond_Hair,Brown_Hair,Male,Young" --output ./infer_result/stargan/
python infer.py --model_net STGAN --init_model ./output/checkpoints/19/ --dataset_dir ./data/celeba/ --image_size 128 --use_gru True python infer.py --model_net STGAN --init_model ./output/checkpoints/19/ --dataset_dir ./data/celeba/ --image_size 128 --use_gru True --output ./infer_result/stgan/
python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --batch_size 16 --epoch 20 > log_out 2>log_err python train.py --model_net StarGAN --dataset celeba --crop_size 178 --image_size 128 --train_list ./data/celeba/list_attr_celeba.txt --batch_size 16 --epoch 20 --gan_mode wgan > log_out 2>log_err
...@@ -55,6 +55,9 @@ class GTrainer(): ...@@ -55,6 +55,9 @@ class GTrainer():
fluid.layers.square( fluid.layers.square(
fluid.layers.elementwise_sub( fluid.layers.elementwise_sub(
x=self.pred_fake, y=ones))) x=self.pred_fake, y=ones)))
else:
raise NotImplementedError("gan_mode {} is not support!".format(
cfg.gan_mode))
self.g_loss_cls = fluid.layers.mean( self.g_loss_cls = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_fake, fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_fake,
...@@ -126,6 +129,9 @@ class DTrainer(): ...@@ -126,6 +129,9 @@ class DTrainer():
cfg=cfg, cfg=cfg,
name="discriminator") name="discriminator")
self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
else:
raise NotImplementedError("gan_mode {} is not support!".format(
cfg.gan_mode))
self.d_loss_real.persistable = True self.d_loss_real.persistable = True
self.d_loss_fake.persistable = True self.d_loss_fake.persistable = True
...@@ -153,11 +159,11 @@ class DTrainer(): ...@@ -153,11 +159,11 @@ class DTrainer():
beta = fluid.layers.uniform_random_batch_size_like( beta = fluid.layers.uniform_random_batch_size_like(
input=a, shape=a.shape, min=0.0, max=1.0) input=a, shape=a.shape, min=0.0, max=1.0)
mean = fluid.layers.reduce_mean( mean = fluid.layers.reduce_mean(
a, range(len(a.shape)), keep_dim=True) a, dim=list(range(len(a.shape))), keep_dim=True)
input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0) input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
var = fluid.layers.reduce_mean( var = fluid.layers.reduce_mean(
fluid.layers.square(input_sub_mean), fluid.layers.square(input_sub_mean),
range(len(a.shape)), dim=list(range(len(a.shape))),
keep_dim=True) keep_dim=True)
b = beta * fluid.layers.sqrt(var) * 0.5 + a b = beta * fluid.layers.sqrt(var) * 0.5 + a
shape = [a.shape[0]] shape = [a.shape[0]]
...@@ -297,7 +303,10 @@ class AttGAN(object): ...@@ -297,7 +303,10 @@ class AttGAN(object):
# prepare environment # prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
py_reader.decorate_batch_generator(self.train_reader, places=place) py_reader.decorate_batch_generator(
self.train_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -371,7 +380,9 @@ class AttGAN(object): ...@@ -371,7 +380,9 @@ class AttGAN(object):
iterable=True, iterable=True,
use_double_buffer=True) use_double_buffer=True)
test_py_reader.decorate_batch_generator( test_py_reader.decorate_batch_generator(
self.test_reader, places=place) self.test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
test_program = test_gen_trainer.infer_program test_program = test_gen_trainer.infer_program
utility.save_test_image(epoch_id, self.cfg, exe, place, utility.save_test_image(epoch_id, self.cfg, exe, place,
......
...@@ -122,8 +122,11 @@ class CGAN(object): ...@@ -122,8 +122,11 @@ class CGAN(object):
d_trainer.program).with_data_parallel( d_trainer.program).with_data_parallel(
loss_name=d_trainer.d_loss.name, build_strategy=build_strategy) loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)
if self.cfg.run_test:
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
t_time = 0 t_time = 0
losses = [[], []]
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
for batch_id, data in enumerate(self.train_reader()): for batch_id, data in enumerate(self.train_reader()):
if len(data) != self.cfg.batch_size: if len(data) != self.cfg.batch_size:
...@@ -159,13 +162,12 @@ class CGAN(object): ...@@ -159,13 +162,12 @@ class CGAN(object):
fetch_list=[d_trainer.d_loss])[0] fetch_list=[d_trainer.d_loss])[0]
d_fake_loss = exe.run(d_trainer_program, d_fake_loss = exe.run(d_trainer_program,
feed={ feed={
'img': generate_image, 'img': generate_image[0],
'condition': condition_data, 'condition': condition_data,
'label': fake_label 'label': fake_label
}, },
fetch_list=[d_trainer.d_loss])[0] fetch_list=[d_trainer.d_loss])[0]
d_loss = d_real_loss + d_fake_loss d_loss = d_real_loss + d_fake_loss
losses[1].append(d_loss)
for _ in six.moves.xrange(self.cfg.num_generator_time): for _ in six.moves.xrange(self.cfg.num_generator_time):
g_loss = exe.run(g_trainer_program, g_loss = exe.run(g_trainer_program,
...@@ -174,15 +176,16 @@ class CGAN(object): ...@@ -174,15 +176,16 @@ class CGAN(object):
'condition': condition_data 'condition': condition_data
}, },
fetch_list=[g_trainer.g_loss])[0] fetch_list=[g_trainer.g_loss])[0]
losses[0].append(g_loss)
batch_time = time.time() - s_time batch_time = time.time() - s_time
if batch_id % self.cfg.print_freq == 0:
print(
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'.
format(epoch_id, batch_id, d_loss[0], g_loss[0],
batch_time))
t_time += batch_time t_time += batch_time
if batch_id % self.cfg.print_freq == 0: if self.cfg.run_test:
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
generate_const_image = exe.run( generate_const_image = exe.run(
g_trainer.infer_program, g_trainer.infer_program,
feed={'noise': const_n, feed={'noise': const_n,
...@@ -194,10 +197,6 @@ class CGAN(object): ...@@ -194,10 +197,6 @@ class CGAN(object):
total_images = np.concatenate( total_images = np.concatenate(
[real_image, generate_image_reshape]) [real_image, generate_image_reshape])
fig = utility.plot(total_images) fig = utility.plot(total_images)
print(
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'.
format(epoch_id, batch_id, d_loss[0], g_loss[0],
batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id, plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id)) batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id) img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
......
...@@ -259,8 +259,14 @@ class CycleGAN(object): ...@@ -259,8 +259,14 @@ class CycleGAN(object):
# prepare environment # prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
A_py_reader.decorate_batch_generator(self.A_reader, places=place) A_py_reader.decorate_batch_generator(
B_py_reader.decorate_batch_generator(self.B_reader, places=place) self.A_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
B_py_reader.decorate_batch_generator(
self.B_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -359,9 +365,13 @@ class CycleGAN(object): ...@@ -359,9 +365,13 @@ class CycleGAN(object):
use_double_buffer=True) use_double_buffer=True)
A_test_py_reader.decorate_batch_generator( A_test_py_reader.decorate_batch_generator(
self.A_test_reader, places=place) self.A_test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
B_test_py_reader.decorate_batch_generator( B_test_py_reader.decorate_batch_generator(
self.B_test_reader, places=place) self.B_test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
test_program = gen_trainer.infer_program test_program = gen_trainer.infer_program
utility.save_test_image( utility.save_test_image(
epoch_id, epoch_id,
......
...@@ -95,7 +95,7 @@ class DCGAN(object): ...@@ -95,7 +95,7 @@ class DCGAN(object):
d_trainer = DTrainer(img, label, self.cfg) d_trainer = DTrainer(img, label, self.cfg)
# prepare enviorment # prepare enviorment
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -118,6 +118,11 @@ class DCGAN(object): ...@@ -118,6 +118,11 @@ class DCGAN(object):
d_trainer.program).with_data_parallel( d_trainer.program).with_data_parallel(
loss_name=d_trainer.d_loss.name, build_strategy=build_strategy) loss_name=d_trainer.d_loss.name, build_strategy=build_strategy)
if self.cfg.run_test:
image_path = os.path.join(self.cfg.output, 'images')
if not os.path.exists(image_path):
os.makedirs(image_path)
t_time = 0 t_time = 0
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
for batch_id, data in enumerate(self.train_reader()): for batch_id, data in enumerate(self.train_reader()):
...@@ -148,7 +153,7 @@ class DCGAN(object): ...@@ -148,7 +153,7 @@ class DCGAN(object):
fetch_list=[d_trainer.d_loss])[0] fetch_list=[d_trainer.d_loss])[0]
d_fake_loss = exe.run( d_fake_loss = exe.run(
d_trainer_program, d_trainer_program,
feed={'img': generate_image, feed={'img': generate_image[0],
'label': fake_label}, 'label': fake_label},
fetch_list=[d_trainer.d_loss])[0] fetch_list=[d_trainer.d_loss])[0]
d_loss = d_real_loss + d_fake_loss d_loss = d_real_loss + d_fake_loss
...@@ -164,12 +169,16 @@ class DCGAN(object): ...@@ -164,12 +169,16 @@ class DCGAN(object):
fetch_list=[g_trainer.g_loss])[0] fetch_list=[g_trainer.g_loss])[0]
batch_time = time.time() - s_time batch_time = time.time() - s_time
t_time += batch_time
if batch_id % self.cfg.print_freq == 0: if batch_id % self.cfg.print_freq == 0:
image_path = os.path.join(self.cfg.output, 'images') print(
if not os.path.exists(image_path): 'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'.
os.makedirs(image_path) format(epoch_id, batch_id, d_loss[0], g_loss[0],
batch_time))
t_time += batch_time
if self.cfg.run_test:
generate_const_image = exe.run( generate_const_image = exe.run(
g_trainer.infer_program, g_trainer.infer_program,
feed={'noise': const_n}, feed={'noise': const_n},
...@@ -181,10 +190,6 @@ class DCGAN(object): ...@@ -181,10 +190,6 @@ class DCGAN(object):
[real_image, generate_image_reshape]) [real_image, generate_image_reshape])
fig = utility.plot(total_images) fig = utility.plot(total_images)
print(
'Epoch ID: {} Batch ID: {} D_loss: {} G_loss: {} Batch_time_cost: {}'.
format(epoch_id, batch_id, d_loss[0], g_loss[0],
batch_time))
plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id, plt.title('Epoch ID={}, Batch ID={}'.format(epoch_id,
batch_id)) batch_id))
img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id) img_name = '{:04d}_{:04d}.png'.format(epoch_id, batch_id)
......
...@@ -56,6 +56,9 @@ class GTrainer(): ...@@ -56,6 +56,9 @@ class GTrainer():
self.g_loss_gan = fluid.layers.mean( self.g_loss_gan = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits( fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.pred, label=ones)) x=self.pred, label=ones))
else:
raise NotImplementedError("gan_mode {} is not support!".format(
cfg.gan_mode))
self.g_loss_L1 = fluid.layers.reduce_mean( self.g_loss_L1 = fluid.layers.reduce_mean(
fluid.layers.abs( fluid.layers.abs(
...@@ -140,6 +143,10 @@ class DTrainer(): ...@@ -140,6 +143,10 @@ class DTrainer():
self.d_loss_fake = fluid.layers.mean( self.d_loss_fake = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits( fluid.layers.sigmoid_cross_entropy_with_logits(
x=self.pred_fake, label=zeros)) x=self.pred_fake, label=zeros))
else:
raise NotImplementedError("gan_mode {} is not support!".format(
cfg.gan_mode))
self.d_loss = 0.5 * (self.d_loss_real + self.d_loss_fake) self.d_loss = 0.5 * (self.d_loss_real + self.d_loss_fake)
vars = [] vars = []
for var in self.program.list_vars(): for var in self.program.list_vars():
...@@ -225,7 +232,10 @@ class Pix2pix(object): ...@@ -225,7 +232,10 @@ class Pix2pix(object):
# prepare environment # prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
py_reader.decorate_batch_generator(self.train_reader, places=place) py_reader.decorate_batch_generator(
self.train_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -299,7 +309,9 @@ class Pix2pix(object): ...@@ -299,7 +309,9 @@ class Pix2pix(object):
iterable=True, iterable=True,
use_double_buffer=True) use_double_buffer=True)
test_py_reader.decorate_batch_generator( test_py_reader.decorate_batch_generator(
self.test_reader, places=place) self.test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
test_program = gen_trainer.infer_program test_program = gen_trainer.infer_program
utility.save_test_image( utility.save_test_image(
epoch_id, epoch_id,
......
...@@ -54,6 +54,9 @@ class GTrainer(): ...@@ -54,6 +54,9 @@ class GTrainer():
fluid.layers.square( fluid.layers.square(
fluid.layers.elementwise_sub( fluid.layers.elementwise_sub(
x=self.pred_fake, y=ones))) x=self.pred_fake, y=ones)))
else:
raise NotImplementedError("gan_mode {} is not support!".format(
cfg.gan_mode))
self.g_loss_cls = fluid.layers.mean( self.g_loss_cls = fluid.layers.mean(
fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_fake, fluid.layers.sigmoid_cross_entropy_with_logits(self.cls_fake,
...@@ -128,6 +131,9 @@ class DTrainer(): ...@@ -128,6 +131,9 @@ class DTrainer():
cfg=cfg, cfg=cfg,
name="discriminator") name="discriminator")
self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp self.d_loss = self.d_loss_real + self.d_loss_fake + 1.0 * self.d_loss_cls + cfg.lambda_gp * self.d_loss_gp
else:
raise NotImplementedError("gan_mode {} is not support!".format(
cfg.gan_mode))
self.d_loss_real.persistable = True self.d_loss_real.persistable = True
self.d_loss_fake.persistable = True self.d_loss_fake.persistable = True
...@@ -159,11 +165,11 @@ class DTrainer(): ...@@ -159,11 +165,11 @@ class DTrainer():
beta = fluid.layers.uniform_random_batch_size_like( beta = fluid.layers.uniform_random_batch_size_like(
input=a, shape=a.shape, min=0.0, max=1.0) input=a, shape=a.shape, min=0.0, max=1.0)
mean = fluid.layers.reduce_mean( mean = fluid.layers.reduce_mean(
a, range(len(a.shape)), keep_dim=True) a, dim=list(range(len(a.shape))), keep_dim=True)
input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0) input_sub_mean = fluid.layers.elementwise_sub(a, mean, axis=0)
var = fluid.layers.reduce_mean( var = fluid.layers.reduce_mean(
fluid.layers.square(input_sub_mean), fluid.layers.square(input_sub_mean),
range(len(a.shape)), dim=list(range(len(a.shape))),
keep_dim=True) keep_dim=True)
b = beta * fluid.layers.sqrt(var) * 0.5 + a b = beta * fluid.layers.sqrt(var) * 0.5 + a
shape = [a.shape[0]] shape = [a.shape[0]]
...@@ -308,7 +314,10 @@ class STGAN(object): ...@@ -308,7 +314,10 @@ class STGAN(object):
# prepare environment # prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
py_reader.decorate_batch_generator(self.train_reader, places=place) py_reader.decorate_batch_generator(
self.train_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -380,7 +389,9 @@ class STGAN(object): ...@@ -380,7 +389,9 @@ class STGAN(object):
iterable=True, iterable=True,
use_double_buffer=True) use_double_buffer=True)
test_py_reader.decorate_batch_generator( test_py_reader.decorate_batch_generator(
self.test_reader, places=place) self.test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
test_program = test_gen_trainer.infer_program test_program = test_gen_trainer.infer_program
utility.save_test_image(epoch_id, self.cfg, exe, place, utility.save_test_image(epoch_id, self.cfg, exe, place,
test_program, test_gen_trainer, test_program, test_gen_trainer,
......
...@@ -42,6 +42,10 @@ class GTrainer(): ...@@ -42,6 +42,10 @@ class GTrainer():
x=image_real, y=self.rec_img))) x=image_real, y=self.rec_img)))
self.pred_fake, self.cls_fake = model.network_D( self.pred_fake, self.cls_fake = model.network_D(
self.fake_img, cfg, name="d_main") self.fake_img, cfg, name="d_main")
if cfg.gan_mode != 'wgan':
raise NotImplementedError(
"gan_mode {} is not support! only support wgan".format(
cfg.gan_mode))
#wgan #wgan
self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake) self.g_loss_fake = -1 * fluid.layers.mean(self.pred_fake)
...@@ -104,6 +108,10 @@ class DTrainer(): ...@@ -104,6 +108,10 @@ class DTrainer():
self.d_loss_cls = fluid.layers.reduce_sum( self.d_loss_cls = fluid.layers.reduce_sum(
fluid.layers.sigmoid_cross_entropy_with_logits( fluid.layers.sigmoid_cross_entropy_with_logits(
self.cls_real, label_org)) / cfg.batch_size self.cls_real, label_org)) / cfg.batch_size
if cfg.gan_mode != 'wgan':
raise NotImplementedError(
"gan_mode {} is not support! only support wgan".format(
cfg.gan_mode))
#wgan #wgan
self.d_loss_fake = fluid.layers.mean(self.pred_fake) self.d_loss_fake = fluid.layers.mean(self.pred_fake)
self.d_loss_real = -1 * fluid.layers.mean(self.pred_real) self.d_loss_real = -1 * fluid.layers.mean(self.pred_real)
...@@ -273,7 +281,10 @@ class StarGAN(object): ...@@ -273,7 +281,10 @@ class StarGAN(object):
# prepare environment # prepare environment
place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.cfg.use_gpu else fluid.CPUPlace()
py_reader.decorate_batch_generator(self.train_reader, places=place) py_reader.decorate_batch_generator(
self.train_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -346,7 +357,9 @@ class StarGAN(object): ...@@ -346,7 +357,9 @@ class StarGAN(object):
iterable=True, iterable=True,
use_double_buffer=True) use_double_buffer=True)
test_py_reader.decorate_batch_generator( test_py_reader.decorate_batch_generator(
self.test_reader, places=place) self.test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
test_program = gen_trainer.infer_program test_program = gen_trainer.infer_program
utility.save_test_image(epoch_id, self.cfg, exe, place, utility.save_test_image(epoch_id, self.cfg, exe, place,
test_program, gen_trainer, test_program, gen_trainer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册