未验证 提交 63d68a80 编写于 作者: L lvmengsi 提交者: GitHub

Cherry pick gan (#2737)

Cherry pick gan
上级 ee491be3
......@@ -18,7 +18,9 @@
本图像生成模型库包含CGAN\[[3](#参考文献)\], DCGAN\[[4](#参考文献)\], Pix2Pix\[[5](#参考文献)\], CycleGAN\[[6](#参考文献)\], StarGAN\[[7](#参考文献)\], AttGAN\[[8](#参考文献)\], STGAN\[[9](#参考文献)\]
注意:AttGAN和STGAN的网络结构中,判别器去掉了instance norm。
注意:
1. AttGAN和STGAN的网络结构中,判别器去掉了instance norm。
2. StarGAN,AttGAN和STGAN由于梯度惩罚所需的操作目前只支持GPU,需使用GPU训练。
图像生成模型库库的目录结构如下:
```
......@@ -235,7 +237,7 @@ STGAN的网络结构[9]
**Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?
**A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,可以参考模型库中预测代码的StarGAN的标签设置
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,保持要转变的标签不变即可
## 参考论文
......
......@@ -27,7 +27,7 @@ import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict
from util.utility import check_attribute_conflict, check_gpu
import copy
parser = argparse.ArgumentParser(description=__doc__)
......@@ -122,6 +122,8 @@ def infer(args):
if not os.path.exists(args.output):
os.makedirs(args.output)
attr_names = args.selected_attrs.split(',')
if args.model_net == 'AttGAN' or args.model_net == 'STGAN':
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
......@@ -133,7 +135,6 @@ def infer(args):
args, shuffle=False, return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
print("read {}".format(name))
label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor()
......@@ -189,10 +190,11 @@ def infer(args):
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
images = [real_img_temp]
for i in range(args.c_dim):
label_trg = np.zeros(
[len(label_org), args.c_dim]).astype("float32")
label_trg_tmp = copy.deepcopy(label_org)
for j in range(len(label_org)):
label_trg[j][i] = 1
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
out = exe.run(
......@@ -233,4 +235,5 @@ def infer(args):
if __name__ == "__main__":
args = parser.parse_args()
print_arguments(args)
check_gpu(args.use_gpu)
infer(args)
......@@ -76,6 +76,7 @@ def train(cfg):
if __name__ == "__main__":
cfg = config.parse_args()
config.print_arguments(cfg)
utility.check_gpu(cfg.use_gpu)
#assert cfg.load_size >= cfg.crop_size, "Load Size CANNOT less than Crop Size!"
if cfg.profile:
if cfg.use_gpu:
......
......@@ -308,12 +308,8 @@ class StarGAN(object):
loss_name=dis_trainer.d_loss.name,
build_strategy=build_strategy)
#losses = [[], []]
t_time = 0
test_program = gen_trainer.infer_program
utility.save_test_image(0, self.cfg, exe, place, test_program,
gen_trainer, self.test_reader)
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for i in range(self.batch_num):
......
......@@ -104,6 +104,7 @@ def save_test_image(epoch,
elif cfg.model_net == "StarGAN":
for data in zip(A_test_reader()):
real_img, label_org, name = data[0]
attr_names = cfg.selected_attrs.split(',')
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
......@@ -111,8 +112,10 @@ def save_test_image(epoch,
real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg = np.zeros([1, cfg.c_dim]).astype("float32")
label_trg[0][i] = 1
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......@@ -261,3 +264,22 @@ def check_attribute_conflict(label_batch, attr, attrs):
if a != attr:
_set(label, 0, a)
return label_batch
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册