未验证 提交 f01ee9b1 编写于 作者: L lvmengsi 提交者: GitHub

fix_infer_pix2pixreader (#3371)

* fix_infer_pix2pixreader
上级 2d3e8c5b
......@@ -67,11 +67,9 @@ def get_preprocess_param(load_width, load_height, crop_width, crop_height):
x = np.random.randint(0, np.maximum(0, load_width - crop_width))
y = np.random.randint(0, np.maximum(0, load_height - crop_height))
flip = np.random.rand() > 0.5
return {
"crop_pos": (x, y),
"flip": flip}
return {"crop_pos": (x, y), "flip": flip}
class reader_creator(object):
''' read and preprocess dataset'''
......@@ -108,7 +106,7 @@ class reader_creator(object):
if self.shuffle:
np.random.shuffle(self.lines)
for i, file in enumerate(self.lines):
file = file.strip('\n\r\t ')
self.name2id[os.path.basename(file)] = i
......@@ -256,17 +254,16 @@ class triplex_reader_creator(reader_creator):
print("files is not equal to 3!")
sys.exit(-1)
#label image instance
img1 = Image.open(os.path.join(self.image_dir, files[
0]))
img1 = Image.open(os.path.join(self.image_dir, files[0]))
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
if not args.no_instance:
img3 = Image.open(os.path.join(self.image_dir, files[
2]))
img3 = Image.open(os.path.join(self.image_dir, files[2]))
if self.mode == "TRAIN":
param = get_preprocess_param(args.load_width, args.load_height,
args.crop_width, args.crop_height)
param = get_preprocess_param(
args.load_width, args.load_height, args.crop_width,
args.crop_height)
img1 = img1.resize((args.load_width, args.load_height),
Image.NEAREST)
img2 = img2.resize((args.load_width, args.load_height),
......@@ -275,10 +272,13 @@ class triplex_reader_creator(reader_creator):
img3 = img3.resize((args.load_width, args.load_height),
Image.NEAREST)
if args.crop_type == 'Centor':
img1 = CentorCrop(img1, args.crop_width, args.crop_height)
img2 = CentorCrop(img2, args.crop_width, args.crop_height)
img1 = CentorCrop(img1, args.crop_width,
args.crop_height)
img2 = CentorCrop(img2, args.crop_width,
args.crop_height)
if not args.no_instance:
img3 = CentorCrop(img3, args.crop_width, args.crop_height)
img3 = CentorCrop(img3, args.crop_width,
args.crop_height)
elif args.crop_type == 'Random':
x = param['crop_pos'][0]
y = param['crop_pos'][1]
......@@ -287,8 +287,8 @@ class triplex_reader_creator(reader_creator):
img2 = img2.crop(
(x, y, x + args.crop_width, y + args.crop_height))
if not args.no_instance:
img3 = img3.crop(
(x, y, x + args.crop_width, y + args.crop_height))
img3 = img3.crop((x, y, x + args.crop_width,
y + args.crop_height))
else:
img1 = img1.resize((args.crop_width, args.crop_height),
Image.NEAREST)
......@@ -299,9 +299,10 @@ class triplex_reader_creator(reader_creator):
Image.NEAREST)
img1 = np.array(img1)
index = img1[np.newaxis, :,:]
input_label = np.zeros((args.label_nc, index.shape[1], index.shape[2]))
np.put_along_axis(input_label,index,1.0,0)
index = img1[np.newaxis, :, :]
input_label = np.zeros(
(args.label_nc, index.shape[1], index.shape[2]))
np.put_along_axis(input_label, index, 1.0, 0)
img1 = input_label
img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1])
......@@ -311,10 +312,14 @@ class triplex_reader_creator(reader_creator):
###extracte edge from instance
edge = np.zeros(img3.shape)
edge = edge.astype('int8')
edge[:, :, 1:] = edge[:, :, 1:] | (img3[:, :, 1:] != img3[:, :, :-1])
edge[:, :, :-1] = edge[:, :, :-1] | (img3[:, :, 1:] != img3[:, :, :-1])
edge[:, 1:, :] = edge[:, 1:, :] | (img3[:, 1:, :] != img3[:, :-1, :])
edge[:, :-1, :] = edge[:, :-1, :] | (img3[:, 1:, :] != img3[:, :-1, :])
edge[:, :, 1:] = edge[:, :, 1:] | (
img3[:, :, 1:] != img3[:, :, :-1])
edge[:, :, :-1] = edge[:, :, :-1] | (
img3[:, :, 1:] != img3[:, :, :-1])
edge[:, 1:, :] = edge[:, 1:, :] | (
img3[:, 1:, :] != img3[:, :-1, :])
edge[:, :-1, :] = edge[:, :-1, :] | (
img3[:, 1:, :] != img3[:, :-1, :])
img3 = edge.astype('float32')
###end extracte
batch_out_1.append(img1)
......@@ -594,9 +599,10 @@ class data_reader(object):
mode="TEST")
reader_test = test_reader.make_reader(
self.cfg, return_name=True)
id2name = test_reader.id2name
batch_num = train_reader.len()
reader = train_reader.make_reader(self.cfg)
return reader, reader_test, batch_num
return reader, reader_test, batch_num, id2name
elif self.cfg.model_net in ['SPADE']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
......
......@@ -26,7 +26,7 @@ import numpy as np
import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creato
from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creator
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
from util import utility
import copy
......@@ -170,8 +170,10 @@ def infer(args):
elif args.model_net == 'SPADE':
from network.SPADE_network import SPADE_model
model = SPADE_model()
input_label = fluid.layers.data(name='input_label', shape=data_shape, dtype='float32')
input_ins = fluid.layers.data(name='input_ins', shape=data_shape, dtype='float32')
input_label = fluid.layers.data(
name='input_label', shape=data_shape, dtype='float32')
input_ins = fluid.layers.data(
name='input_ins', shape=data_shape, dtype='float32')
input_ = fluid.layers.concat([input_label, input_ins], 1)
fake = model.network_G(input_, "generator", cfg=args, is_test=True)
else:
......@@ -316,8 +318,7 @@ def infer(args):
shuffle=False,
batch_size=1,
mode="TEST")
reader_test = test_reader.make_reader(
args, return_name=True)
reader_test = test_reader.make_reader(args, return_name=True)
for data in zip(reader_test()):
data_A, data_B, data_C, name = data[0]
name = name[0]
......
......@@ -342,12 +342,6 @@ class SPADE(object):
for tensor in py_reader():
data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][
'input_B'], tensor[0]['input_C']
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_C = fluid.LoDTensor()
tensor_A.set(data_A, place)
tensor_B.set(data_B, place)
tensor_C.set(data_C, place)
s_time = time.time()
# optimize the generator network
g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run(
......@@ -357,9 +351,9 @@ class SPADE(object):
gen_trainer.gan_feat_loss, gen_trainer.fake_B
],
feed={
"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C
"input_label": data_A,
"input_img": data_B,
"input_ins": data_C
})
# optimize the discriminator network
......@@ -369,9 +363,9 @@ class SPADE(object):
dis_trainer.gan_loss_real, dis_trainer.gan_loss_fake
],
feed={
"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C,
"input_label": data_A,
"input_img": data_B,
"input_ins": data_C,
"input_fake": fake_B_tmp
})
......
......@@ -172,28 +172,24 @@ def save_test_image(epoch,
res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "SPADE":
for data in A_test_reader():
data_A, data_B, data_C, name = data[0]['input_A'], data[0]['input_B'], data[0]['input_C'], data[0]['image_name']
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_C = fluid.LoDTensor()
tensor_A.set(data_A, place)
tensor_B.set(data_B, place)
tensor_C.set(data_C, place)
fake_B_temp = exe.run(
test_program,
fetch_list=[g_trainer.fake_B],
feed={"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C})
data_A, data_B, data_C, name = data[0]['input_A'], data[0][
'input_B'], data[0]['input_C'], data[0]['image_name']
fake_B_temp = exe.run(test_program,
fetch_list=[g_trainer.fake_B],
feed={
"input_label": data_A,
"input_img": data_B,
"input_ins": data_C
})
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
res_fakeB.save(out_path+"/fakeB_"+str(epoch)+"_"+name)
res_fakeB.save(out_path + "/fakeB_" + str(epoch) + "_" + name)
res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_real.save(out_path+"/real_"+str(epoch)+"_"+name)
res_real.save(out_path + "/real_" + str(epoch) + "_" + name)
elif cfg.model_net == "StarGAN":
for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册