未验证 提交 811680c0 编写于 作者: L lvmengsi 提交者: GitHub

fix bn=1 when attgan infer (#2786)

* fix_bn=1
上级 6d1d55fd
......@@ -27,7 +27,7 @@ import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator
from util.utility import check_attribute_conflict, check_gpu
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
import copy
parser = argparse.ArgumentParser(description=__doc__)
......@@ -144,7 +144,8 @@ def infer(args):
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
......@@ -165,10 +166,11 @@ def infer(args):
"label_trg_": tensor_label_trg_
},
fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
images_concat = np.concatenate(images_concat, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'StarGAN':
......@@ -187,7 +189,8 @@ def infer(args):
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = np.squeeze(real_img).transpose([0, 2, 3, 1])
real_img_temp = save_batch_image(real_img)
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
......@@ -201,10 +204,11 @@ def infer(args):
feed={"input": tensor_img,
"label_trg_": tensor_label_trg},
fetch_list=[fake.name])
fake_temp = np.squeeze(out[0]).transpose([0, 2, 3, 1])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
images_concat = np.concatenate(images_concat, 1)
if len(label_org) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
......
......@@ -114,8 +114,8 @@ def save_test_image(epoch,
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
label_trg_tmp[0][i] = 1.0 - label_trg_tmp[0][i]
label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_trg = check_attribute_conflict(label_trg_tmp,
attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
fake_temp, rec_temp = exe.run(
......@@ -266,20 +266,28 @@ def check_attribute_conflict(label_batch, attr, attrs):
return label_batch
def save_batch_image(img):
if len(img) == 1:
res_img = np.squeeze(img).transpose([1, 2, 0])
else:
res_img = np.squeeze(img).transpose([0, 2, 3, 1])
return res_img
def check_gpu(use_gpu):
"""
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
err = "Config use_gpu cannot be set as true while you are " \
"using paddlepaddle cpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
"\t2. Set use_gpu as false in config file to run " \
"model on CPU"
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
try:
if use_gpu and not fluid.is_compiled_with_cuda():
logger.error(err)
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册