未验证 提交 e14d5bc1 编写于 作者: L lvmengsi 提交者: GitHub

fix infer (#2707)

上级 9d690da1
......@@ -237,7 +237,7 @@ STGAN的网络结构[9]
**Q:** 为什么STGAN和ATTGAN中变男性得到的预测结果是变女性呢?
**A:** 这是由于预测时标签的设置,目标标签是基于原本的标签进行改变,比如原本图片是男生,预测代码对标签进行转变的时候会自动变成相对立的标签,即女
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,可以参考模型库中预测代码的StarGAN的标签设置
性,所以得到的结果是女生。如果想要原本是男生,转变之后还是男生,保持要转变的标签不变即可
## 参考论文
......
......@@ -122,6 +122,8 @@ def infer(args):
if not os.path.exists(args.output):
os.makedirs(args.output)
attr_names = args.selected_attrs.split(',')
if args.model_net == 'AttGAN' or args.model_net == 'STGAN':
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
......@@ -133,7 +135,6 @@ def infer(args):
args, shuffle=False, return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
attr_names = args.selected_attrs.split(',')
print("read {}".format(name))
label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor()
......@@ -189,10 +190,11 @@ def infer(args):
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
images = [real_img_temp]
for i in range(args.c_dim):
label_trg = np.zeros(
[len(label_org), args.c_dim]).astype("float32")
label_trg_tmp = copy.deepcopy(label_org)
for j in range(len(label_org)):
label_trg[j][i] = 1
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
out = exe.run(
......
......@@ -104,6 +104,7 @@ def save_test_image(epoch,
elif cfg.model_net == "StarGAN":
for data in zip(A_test_reader()):
real_img, label_org, name = data[0]
attr_names = cfg.selected_attrs.split(',')
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
......@@ -111,8 +112,10 @@ def save_test_image(epoch,
real_img_temp = np.squeeze(real_img).transpose([1, 2, 0])
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg = np.zeros([1, cfg.c_dim]).astype("float32")
label_trg[0][i] = 1
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册