未验证 提交 f01ee9b1 编写于 作者: L lvmengsi 提交者: GitHub

fix_infer_pix2pixreader (#3371)

* fix_infer_pix2pixreader
上级 2d3e8c5b
...@@ -67,11 +67,9 @@ def get_preprocess_param(load_width, load_height, crop_width, crop_height): ...@@ -67,11 +67,9 @@ def get_preprocess_param(load_width, load_height, crop_width, crop_height):
x = np.random.randint(0, np.maximum(0, load_width - crop_width)) x = np.random.randint(0, np.maximum(0, load_width - crop_width))
y = np.random.randint(0, np.maximum(0, load_height - crop_height)) y = np.random.randint(0, np.maximum(0, load_height - crop_height))
flip = np.random.rand() > 0.5 flip = np.random.rand() > 0.5
return { return {"crop_pos": (x, y), "flip": flip}
"crop_pos": (x, y),
"flip": flip}
class reader_creator(object): class reader_creator(object):
''' read and preprocess dataset''' ''' read and preprocess dataset'''
...@@ -108,7 +106,7 @@ class reader_creator(object): ...@@ -108,7 +106,7 @@ class reader_creator(object):
if self.shuffle: if self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
for i, file in enumerate(self.lines): for i, file in enumerate(self.lines):
file = file.strip('\n\r\t ') file = file.strip('\n\r\t ')
self.name2id[os.path.basename(file)] = i self.name2id[os.path.basename(file)] = i
...@@ -256,17 +254,16 @@ class triplex_reader_creator(reader_creator): ...@@ -256,17 +254,16 @@ class triplex_reader_creator(reader_creator):
print("files is not equal to 3!") print("files is not equal to 3!")
sys.exit(-1) sys.exit(-1)
#label image instance #label image instance
img1 = Image.open(os.path.join(self.image_dir, files[ img1 = Image.open(os.path.join(self.image_dir, files[0]))
0]))
img2 = Image.open(os.path.join(self.image_dir, files[ img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB') 1])).convert('RGB')
if not args.no_instance: if not args.no_instance:
img3 = Image.open(os.path.join(self.image_dir, files[ img3 = Image.open(os.path.join(self.image_dir, files[2]))
2]))
if self.mode == "TRAIN": if self.mode == "TRAIN":
param = get_preprocess_param(args.load_width, args.load_height, param = get_preprocess_param(
args.crop_width, args.crop_height) args.load_width, args.load_height, args.crop_width,
args.crop_height)
img1 = img1.resize((args.load_width, args.load_height), img1 = img1.resize((args.load_width, args.load_height),
Image.NEAREST) Image.NEAREST)
img2 = img2.resize((args.load_width, args.load_height), img2 = img2.resize((args.load_width, args.load_height),
...@@ -275,10 +272,13 @@ class triplex_reader_creator(reader_creator): ...@@ -275,10 +272,13 @@ class triplex_reader_creator(reader_creator):
img3 = img3.resize((args.load_width, args.load_height), img3 = img3.resize((args.load_width, args.load_height),
Image.NEAREST) Image.NEAREST)
if args.crop_type == 'Centor': if args.crop_type == 'Centor':
img1 = CentorCrop(img1, args.crop_width, args.crop_height) img1 = CentorCrop(img1, args.crop_width,
img2 = CentorCrop(img2, args.crop_width, args.crop_height) args.crop_height)
img2 = CentorCrop(img2, args.crop_width,
args.crop_height)
if not args.no_instance: if not args.no_instance:
img3 = CentorCrop(img3, args.crop_width, args.crop_height) img3 = CentorCrop(img3, args.crop_width,
args.crop_height)
elif args.crop_type == 'Random': elif args.crop_type == 'Random':
x = param['crop_pos'][0] x = param['crop_pos'][0]
y = param['crop_pos'][1] y = param['crop_pos'][1]
...@@ -287,8 +287,8 @@ class triplex_reader_creator(reader_creator): ...@@ -287,8 +287,8 @@ class triplex_reader_creator(reader_creator):
img2 = img2.crop( img2 = img2.crop(
(x, y, x + args.crop_width, y + args.crop_height)) (x, y, x + args.crop_width, y + args.crop_height))
if not args.no_instance: if not args.no_instance:
img3 = img3.crop( img3 = img3.crop((x, y, x + args.crop_width,
(x, y, x + args.crop_width, y + args.crop_height)) y + args.crop_height))
else: else:
img1 = img1.resize((args.crop_width, args.crop_height), img1 = img1.resize((args.crop_width, args.crop_height),
Image.NEAREST) Image.NEAREST)
...@@ -299,9 +299,10 @@ class triplex_reader_creator(reader_creator): ...@@ -299,9 +299,10 @@ class triplex_reader_creator(reader_creator):
Image.NEAREST) Image.NEAREST)
img1 = np.array(img1) img1 = np.array(img1)
index = img1[np.newaxis, :,:] index = img1[np.newaxis, :, :]
input_label = np.zeros((args.label_nc, index.shape[1], index.shape[2])) input_label = np.zeros(
np.put_along_axis(input_label,index,1.0,0) (args.label_nc, index.shape[1], index.shape[2]))
np.put_along_axis(input_label, index, 1.0, 0)
img1 = input_label img1 = input_label
img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5 img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1]) img2 = img2.transpose([2, 0, 1])
...@@ -311,10 +312,14 @@ class triplex_reader_creator(reader_creator): ...@@ -311,10 +312,14 @@ class triplex_reader_creator(reader_creator):
###extract edge from instance ###extract edge from instance
edge = np.zeros(img3.shape) edge = np.zeros(img3.shape)
edge = edge.astype('int8') edge = edge.astype('int8')
edge[:, :, 1:] = edge[:, :, 1:] | (img3[:, :, 1:] != img3[:, :, :-1]) edge[:, :, 1:] = edge[:, :, 1:] | (
edge[:, :, :-1] = edge[:, :, :-1] | (img3[:, :, 1:] != img3[:, :, :-1]) img3[:, :, 1:] != img3[:, :, :-1])
edge[:, 1:, :] = edge[:, 1:, :] | (img3[:, 1:, :] != img3[:, :-1, :]) edge[:, :, :-1] = edge[:, :, :-1] | (
edge[:, :-1, :] = edge[:, :-1, :] | (img3[:, 1:, :] != img3[:, :-1, :]) img3[:, :, 1:] != img3[:, :, :-1])
edge[:, 1:, :] = edge[:, 1:, :] | (
img3[:, 1:, :] != img3[:, :-1, :])
edge[:, :-1, :] = edge[:, :-1, :] | (
img3[:, 1:, :] != img3[:, :-1, :])
img3 = edge.astype('float32') img3 = edge.astype('float32')
###end extract ###end extract
batch_out_1.append(img1) batch_out_1.append(img1)
...@@ -594,9 +599,10 @@ class data_reader(object): ...@@ -594,9 +599,10 @@ class data_reader(object):
mode="TEST") mode="TEST")
reader_test = test_reader.make_reader( reader_test = test_reader.make_reader(
self.cfg, return_name=True) self.cfg, return_name=True)
id2name = test_reader.id2name
batch_num = train_reader.len() batch_num = train_reader.len()
reader = train_reader.make_reader(self.cfg) reader = train_reader.make_reader(self.cfg)
return reader, reader_test, batch_num return reader, reader_test, batch_num, id2name
elif self.cfg.model_net in ['SPADE']: elif self.cfg.model_net in ['SPADE']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset) dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt') train_list = os.path.join(dataset_dir, 'train.txt')
......
...@@ -26,7 +26,7 @@ import numpy as np ...@@ -26,7 +26,7 @@ import numpy as np
import imageio import imageio
import glob import glob
from util.config import add_arguments, print_arguments from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creato from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creator
from util.utility import check_attribute_conflict, check_gpu, save_batch_image from util.utility import check_attribute_conflict, check_gpu, save_batch_image
from util import utility from util import utility
import copy import copy
...@@ -170,8 +170,10 @@ def infer(args): ...@@ -170,8 +170,10 @@ def infer(args):
elif args.model_net == 'SPADE': elif args.model_net == 'SPADE':
from network.SPADE_network import SPADE_model from network.SPADE_network import SPADE_model
model = SPADE_model() model = SPADE_model()
input_label = fluid.layers.data(name='input_label', shape=data_shape, dtype='float32') input_label = fluid.layers.data(
input_ins = fluid.layers.data(name='input_ins', shape=data_shape, dtype='float32') name='input_label', shape=data_shape, dtype='float32')
input_ins = fluid.layers.data(
name='input_ins', shape=data_shape, dtype='float32')
input_ = fluid.layers.concat([input_label, input_ins], 1) input_ = fluid.layers.concat([input_label, input_ins], 1)
fake = model.network_G(input_, "generator", cfg=args, is_test=True) fake = model.network_G(input_, "generator", cfg=args, is_test=True)
else: else:
...@@ -316,8 +318,7 @@ def infer(args): ...@@ -316,8 +318,7 @@ def infer(args):
shuffle=False, shuffle=False,
batch_size=1, batch_size=1,
mode="TEST") mode="TEST")
reader_test = test_reader.make_reader( reader_test = test_reader.make_reader(args, return_name=True)
args, return_name=True)
for data in zip(reader_test()): for data in zip(reader_test()):
data_A, data_B, data_C, name = data[0] data_A, data_B, data_C, name = data[0]
name = name[0] name = name[0]
......
...@@ -342,12 +342,6 @@ class SPADE(object): ...@@ -342,12 +342,6 @@ class SPADE(object):
for tensor in py_reader(): for tensor in py_reader():
data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][ data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][
'input_B'], tensor[0]['input_C'] 'input_B'], tensor[0]['input_C']
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_C = fluid.LoDTensor()
tensor_A.set(data_A, place)
tensor_B.set(data_B, place)
tensor_C.set(data_C, place)
s_time = time.time() s_time = time.time()
# optimize the generator network # optimize the generator network
g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run( g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run(
...@@ -357,9 +351,9 @@ class SPADE(object): ...@@ -357,9 +351,9 @@ class SPADE(object):
gen_trainer.gan_feat_loss, gen_trainer.fake_B gen_trainer.gan_feat_loss, gen_trainer.fake_B
], ],
feed={ feed={
"input_label": tensor_A, "input_label": data_A,
"input_img": tensor_B, "input_img": data_B,
"input_ins": tensor_C "input_ins": data_C
}) })
# optimize the discriminator network # optimize the discriminator network
...@@ -369,9 +363,9 @@ class SPADE(object): ...@@ -369,9 +363,9 @@ class SPADE(object):
dis_trainer.gan_loss_real, dis_trainer.gan_loss_fake dis_trainer.gan_loss_real, dis_trainer.gan_loss_fake
], ],
feed={ feed={
"input_label": tensor_A, "input_label": data_A,
"input_img": tensor_B, "input_img": data_B,
"input_ins": tensor_C, "input_ins": data_C,
"input_fake": fake_B_tmp "input_fake": fake_B_tmp
}) })
......
...@@ -172,28 +172,24 @@ def save_test_image(epoch, ...@@ -172,28 +172,24 @@ def save_test_image(epoch,
res_inputB.save(os.path.join(out_path, inputB_name)) res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "SPADE": elif cfg.model_net == "SPADE":
for data in A_test_reader(): for data in A_test_reader():
data_A, data_B, data_C, name = data[0]['input_A'], data[0]['input_B'], data[0]['input_C'], data[0]['image_name'] data_A, data_B, data_C, name = data[0]['input_A'], data[0][
tensor_A = fluid.LoDTensor() 'input_B'], data[0]['input_C'], data[0]['image_name']
tensor_B = fluid.LoDTensor() fake_B_temp = exe.run(test_program,
tensor_C = fluid.LoDTensor() fetch_list=[g_trainer.fake_B],
tensor_A.set(data_A, place) feed={
tensor_B.set(data_B, place) "input_label": data_A,
tensor_C.set(data_C, place) "input_img": data_B,
fake_B_temp = exe.run( "input_ins": data_C
test_program, })
fetch_list=[g_trainer.fake_B],
feed={"input_label": tensor_A,
"input_img": tensor_B,
"input_ins": tensor_C})
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0]) fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0]) input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype( res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8)) np.uint8))
res_fakeB.save(out_path+"/fakeB_"+str(epoch)+"_"+name) res_fakeB.save(out_path + "/fakeB_" + str(epoch) + "_" + name)
res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype( res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8)) np.uint8))
res_real.save(out_path+"/real_"+str(epoch)+"_"+name) res_real.save(out_path + "/real_" + str(epoch) + "_" + name)
elif cfg.model_net == "StarGAN": elif cfg.model_net == "StarGAN":
for data in A_test_reader(): for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][ real_img, label_org, label_trg, image_name = data[0][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册