未验证 提交 63d68a80 编写于 作者: L lvmengsi 提交者: GitHub

Cherry pick gan (#2737)

Cherry pick gan
上级 ee491be3
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\] 本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]
注意:AttGAN和STGAN的网络结构中,判别器去掉了instance norm。 注意:
1. AttGAN和STGAN的网络结构中,判别器去掉了instance norm。
2. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
图像生成模型库库的目录结构如下: 图像生成模型库库的目录结构如下:
``` ```
...@@ -235,7 +237,7 @@ STGAN的网络结构[9] ...@@ -235,7 +237,7 @@ STGAN的网络结构[9]
**Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢? **Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?
**A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女 **A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,可以参考模型库中预测代码的StarGAN的标签设置 性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,保持要转变的标签不变即可
## 参考论文 ## 参考论文
......
...@@ -27,7 +27,7 @@ import imageio ...@@ -27,7 +27,7 @@ import imageio
import glob import glob
from util.config import add_arguments, print_arguments from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict from util.utility import check_attribute_conflict, check_gpu
import copy import copy
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
...@@ -122,6 +122,8 @@ def infer(args): ...@@ -122,6 +122,8 @@ def infer(args):
if not os.path.exists(args.output): if not os.path.exists(args.output):
os.makedirs(args.output) os.makedirs(args.output)
attr_names = args.selected_attrs.split(',')
if args.model_net == 'AttGAN' or args.model_net == 'STGAN': if args.model_net == 'AttGAN' or args.model_net == 'STGAN':
test_reader = celeba_reader_creator( test_reader = celeba_reader_creator(
image_dir=args.dataset_dir, image_dir=args.dataset_dir,
...@@ -133,7 +135,6 @@ def infer(args): ...@@ -133,7 +135,6 @@ def infer(args):
args, shuffle=False, return_name=True) args, shuffle=False, return_name=True)
for data in zip(reader_test()): for data in zip(reader_test()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
print("read {}".format(name)) print("read {}".format(name))
label_trg = copy.deepcopy(label_org) label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
...@@ -189,10 +190,11 @@ def infer(args): ...@@ -189,10 +190,11 @@ def infer(args):
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1]) real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
images = [real_img_temp] images = [real_img_temp]
for i in range(args.c_dim): for i in range(args.c_dim):
label_trg = np.zeros( label_trg_tmp = copy.deepcopy(label_org)
[len(label_org), args.c_dim]).astype("float32")
for j in range(len(label_org)): for j in range(len(label_org)):
label_trg[j][i] = 1 label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor() tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
out = exe.run( out = exe.run(
...@@ -233,4 +235,5 @@ def infer(args): ...@@ -233,4 +235,5 @@ def infer(args):
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
check_gpu(args.use_gpu)
infer(args) infer(args)
...@@ -76,6 +76,7 @@ def train(cfg): ...@@ -76,6 +76,7 @@ def train(cfg):
if __name__ == "__main__": if __name__ == "__main__":
cfg = config.parse_args() cfg = config.parse_args()
config.print_arguments(cfg) config.print_arguments(cfg)
utility.check_gpu(cfg.use_gpu)
#assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!" #assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
if cfg.profile: if cfg.profile:
if cfg.use_gpu: if cfg.use_gpu:
......
...@@ -308,12 +308,8 @@ class StarGAN(object): ...@@ -308,12 +308,8 @@ class StarGAN(object):
loss_name=dis_trainer.d_loss.name, loss_name=dis_trainer.d_loss.name,
build_strategy=build_strategy) build_strategy=build_strategy)
#losses = [[], []]
t_time = 0 t_time = 0
test_program = gen_trainer.infer_program
utility.save_test_image(0, self.cfg, exe, place, test_program,
gen_trainer, self.test_reader)
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
for i in range(self.batch_num): for i in range(self.batch_num):
......
...@@ -104,6 +104,7 @@ def save_test_image(epoch, ...@@ -104,6 +104,7 @@ def save_test_image(epoch,
elif cfg.model_net == "StarGAN": elif cfg.model_net == "StarGAN":
for data in zip(A_test_reader()): for data in zip(A_test_reader()):
real_img, label_org, name = data[0] real_img, label_org, name = data[0]
attr_names = cfg.selected_attrs.split(',')
tensor_img = fluid.LoDTensor() tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor() tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place) tensor_img.set(real_img, place)
...@@ -111,8 +112,10 @@ def save_test_image(epoch, ...@@ -111,8 +112,10 @@ def save_test_image(epoch,
real_img_temp = np.squeeze(real_img).transpose([1, 2, 0]) real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
images = [real_img_temp] images = [real_img_temp]
for i in range(cfg.c_dim): for i in range(cfg.c_dim):
label_trg = np.zeros([1, cfg.c_dim]).astype("float32") label_trg_tmp = copy.deepcopy(label_org)
label_trg[0][i] = 1 label_trg_tmp[0][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor() tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place) tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run( fake_temp, rec_temp = exe.run(
...@@ -261,3 +264,22 @@ def check_attribute_conflict(label_batch, attr, attrs): ...@@ -261,3 +264,22 @@ def check_attribute_conflict(label_batch, attr, attrs):
if a != attr: if a != attr:
_set(label, 0, a) _set(label, 0, a)
return label_batch return label_batch
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册