未验证 提交 811680c0 编写于 作者: L lvmengsi 提交者: GitHub

fix bn=1 when attgan infer (#2786)

* fix_bn=1
上级 6d1d55fd
......@@ -27,7 +27,7 @@ import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict, check_gpu
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
import copy
parser = argparse.ArgumentParser(description=__doc__)
......@@ -144,7 +144,8 @@ def infer(args):
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
......@@ -165,9 +166,10 @@ def infer(args):
"label_trg_": tensor_label_trg_
},
fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
......@@ -187,7 +189,8 @@ def infer(args):
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
......@@ -201,9 +204,10 @@ def infer(args):
feed={"input": tensor_img,
"label_trg_": tensor_label_trg},
fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
......
......@@ -114,8 +114,8 @@ def save_test_image(epoch,
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg = check_attribute_conflict(label_trg_tmp,
attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......@@ -266,6 +266,14 @@ def check_attribute_conflict(label_batch, attr, attrs):
return label_batch
def save_batch_image(img):
if len(img) == 1:
res_img = np.squeeze(img).transpose([1, 2, 0])
else:
res_img = np.squeeze(img).transpose([0, 2, 3, 1])
return res_img
def check_gpu(use_gpu):
"""
Log error and exit when set use_gpu=true in paddlepaddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册