提交 d5e25979 编写于 作者: L lifu

merge spade

上级 bd82886d
......@@ -80,27 +80,24 @@ class reader_creator(object):
list_filename,
shuffle=False,
batch_size=1,
drop_last=False,
mode="TRAIN"):
self.image_dir = image_dir
self.list_filename = list_filename
self.batch_size = batch_size
self.mode = mode
self.name2id = {}
self.id2name = {}
self.lines = open(self.list_filename).readlines()
if self.mode == "TRAIN":
self.shuffle = shuffle
self.drop_last = drop_last
else:
self.shuffle = False
self.drop_last = False
def len(self):
if self.drop_last or len(self.lines) % self.batch_size == 0:
return len(self.lines) // self.batch_size
else:
return len(self.lines) // self.batch_size + 1
return len(self.lines) // self.batch_size
def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename)
......@@ -111,9 +108,11 @@ class reader_creator(object):
if self.shuffle:
np.random.shuffle(self.lines)
for file in self.lines:
for i, file in enumerate(self.lines):
file = file.strip('\n\r\t ')
self.name2id[os.path.basename(file)] = i
self.id2name[i] = os.path.basename(file)
img = Image.open(os.path.join(self.image_dir, file)).convert(
'RGB')
if self.mode == "TRAIN":
......@@ -131,7 +130,7 @@ class reader_creator(object):
if return_name:
batch_out.append(img)
batch_out_name.append(os.path.basename(file))
batch_out_name.append(i)
else:
batch_out.append(img)
if len(batch_out) == self.batch_size:
......@@ -139,13 +138,8 @@ class reader_creator(object):
yield batch_out, batch_out_name
batch_out_name = []
else:
yield batch_out
yield [batch_out]
batch_out = []
if self.drop_last == False and len(batch_out) != 0:
if return_name:
yield batch_out, batch_out_name
else:
yield batch_out
return reader
......@@ -158,14 +152,12 @@ class pair_reader_creator(reader_creator):
list_filename,
shuffle=False,
batch_size=1,
drop_last=False,
mode="TRAIN"):
super(pair_reader_creator, self).__init__(
image_dir,
list_filename,
shuffle=shuffle,
batch_size=batch_size,
drop_last=drop_last,
mode=mode)
def make_reader(self, args, return_name=False):
......@@ -177,13 +169,16 @@ class pair_reader_creator(reader_creator):
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
for i, line in enumerate(self.lines):
files = line.strip('\n\r\t ').split('\t')
img1 = Image.open(os.path.join(self.image_dir, files[
0])).convert('RGB')
img2 = Image.open(os.path.join(self.image_dir, files[
1])).convert('RGB')
self.name2id[os.path.basename(files[0])] = i
self.id2name[i] = os.path.basename(files[0])
if self.mode == "TRAIN":
param = get_preprocess_param(args.image_size,
args.crop_size)
......@@ -215,7 +210,7 @@ class pair_reader_creator(reader_creator):
batch_out_1.append(img1)
batch_out_2.append(img2)
if return_name:
batch_out_name.append(os.path.basename(files[0]))
batch_out_name.append(i)
if len(batch_out_1) == self.batch_size:
if return_name:
yield batch_out_1, batch_out_2, batch_out_name
......@@ -224,11 +219,6 @@ class pair_reader_creator(reader_creator):
yield batch_out_1, batch_out_2
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
if return_name:
yield batch_out_1, batch_out_2, batch_out_name
else:
yield batch_out_1, batch_out_2
return reader
......@@ -241,14 +231,12 @@ class triplex_reader_creator(reader_creator):
list_filename,
shuffle=False,
batch_size=1,
drop_last=False,
mode="TRAIN"):
super(triplex_reader_creator, self).__init__(
image_dir,
list_filename,
shuffle=shuffle,
batch_size=batch_size,
drop_last=drop_last,
mode=mode)
def make_reader(self, args, return_name=False):
......@@ -310,25 +298,13 @@ class triplex_reader_creator(reader_creator):
img3 = img3.resize((args.crop_width, args.crop_height),
Image.NEAREST)
###trans img1 to label
#input_label = np.zeros((args.label_nc, img1.size[1], img1.size[0]))
#for i in range(args.label_nc):
# input_label[i]=np.where(img1==i,1,0)
#img1 = input_label
img1 = np.array(img1)
index = img1[np.newaxis, :,:]
input_label = np.zeros((args.label_nc, index.shape[1], index.shape[2]))
np.put_along_axis(input_label,index,1.0,0)
#TODO:hard code
#input_label = np.ones((args.label_nc, index.shape[1], index.shape[2]))
img1 = input_label
#print(img1)
###end trans
#img1 = img1.transpose([2, 0, 1])
#print(np.array(img2))
img2 = (np.array(img2).astype('float32') / 255.0 - 0.5) / 0.5
img2 = img2.transpose([2, 0, 1])
#print(img2)
if not args.no_instance:
img3 = np.array(img3)[:, :, np.newaxis]
img3 = img3.transpose([2, 0, 1])
......@@ -341,9 +317,6 @@ class triplex_reader_creator(reader_creator):
edge[:, :-1, :] = edge[:, :-1, :] | (img3[:, 1:, :] != img3[:, :-1, :])
img3 = edge.astype('float32')
###end extracte
#print(img3)
#TODO:hard code
#img3 = np.ones((1, index.shape[1], index.shape[2]))
batch_out_1.append(img1)
batch_out_2.append(img2)
if not args.no_instance:
......@@ -365,17 +338,6 @@ class triplex_reader_creator(reader_creator):
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
if self.drop_last == False and len(batch_out_1) != 0:
if return_name:
if not args.no_instance:
yield batch_out_1, batch_out_2, batch_out_3, batch_out_name
else:
yield batch_out_1, batch_out_2, batch_out_name
else:
if not args.no_instance:
yield batch_out_1, batch_out_2, batch_out_3
else:
yield batch_out_1, batch_out_2
return reader
......@@ -399,17 +361,14 @@ class celeba_reader_creator(reader_creator):
attr2idx = {}
for i, attr_name in enumerate(all_attr_names):
attr2idx[attr_name] = i
lines = lines[2:]
if self.mode == "TRAIN":
self.batch_size = args.batch_size
self.drop_last = args.drop_last
self.shuffle = args.shuffle
lines = lines[2:train_end]
else:
self.batch_size = args.n_samples
self.shuffle = False
self.drop_last = False
if self.mode == "TEST":
lines = lines[train_end:test_end]
else:
......@@ -417,20 +376,17 @@ class celeba_reader_creator(reader_creator):
self.images = []
attr_names = args.selected_attrs.split(',')
for line in lines:
for i, line in enumerate(lines):
arr = line.strip().split()
name = os.path.join('img_align_celeba', arr[0])
label = []
for attr_name in attr_names:
idx = attr2idx[attr_name]
label.append(arr[idx + 1] == "1")
self.images.append((name, label))
self.images.append((name, label, arr[0]))
def len(self):
if self.drop_last or len(self.images) % self.batch_size == 0:
return len(self.images) // self.batch_size
else:
return len(self.images) // self.batch_size + 1
return len(self.images) // self.batch_size
def make_reader(self, return_name=False):
print(self.image_dir, self.list_filename)
......@@ -438,10 +394,11 @@ class celeba_reader_creator(reader_creator):
def reader():
batch_out_1 = []
batch_out_2 = []
batch_out_3 = []
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.images)
for file, label in self.images:
for file, label, f_name in self.images:
img = Image.open(os.path.join(self.image_dir, file))
label = np.array(label).astype("float32")
if self.args.model_net == "StarGAN":
......@@ -455,20 +412,19 @@ class celeba_reader_creator(reader_creator):
batch_out_1.append(img)
batch_out_2.append(label)
if return_name:
batch_out_name.append(os.path.basename(file))
batch_out_name.append(int(f_name.split('.')[0]))
if len(batch_out_1) == self.batch_size:
batch_out_3 = np.copy(batch_out_2)
if self.shuffle:
np.random.shuffle(batch_out_3)
if return_name:
yield batch_out_1, batch_out_2, batch_out_name
yield batch_out_1, batch_out_2, batch_out_3, batch_out_name
batch_out_name = []
else:
yield batch_out_1, batch_out_2
yield batch_out_1, batch_out_2, batch_out_3
batch_out_1 = []
batch_out_2 = []
if self.drop_last == False and len(batch_out_1) != 0:
if return_name:
yield batch_out_1, batch_out_2, batch_out_name
else:
yield batch_out_1, batch_out_2
batch_out_3 = []
return reader
......@@ -549,17 +505,17 @@ class data_reader(object):
list_filename=trainA_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last,
mode="TRAIN")
b_train_reader = reader_creator(
image_dir=dataset_dir,
list_filename=trainB_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last,
mode="TRAIN")
a_reader_test = None
b_reader_test = None
a_id2name = None
b_id2name = None
if self.cfg.run_test:
testA_list = os.path.join(dataset_dir, "testA.txt")
testB_list = os.path.join(dataset_dir, "testB.txt")
......@@ -568,25 +524,25 @@ class data_reader(object):
list_filename=testA_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last,
mode="TEST")
b_test_reader = reader_creator(
image_dir=dataset_dir,
list_filename=testB_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last,
mode="TEST")
a_reader_test = a_test_reader.make_reader(
self.cfg, return_name=True)
b_reader_test = b_test_reader.make_reader(
self.cfg, return_name=True)
a_id2name = a_test_reader.id2name
b_id2name = b_test_reader.id2name
batch_num = max(a_train_reader.len(), b_train_reader.len())
a_reader = a_train_reader.make_reader(self.cfg)
b_reader = b_train_reader.make_reader(self.cfg)
return a_reader, b_reader, a_reader_test, b_reader_test, batch_num
return a_reader, b_reader, a_reader_test, b_reader_test, batch_num, a_id2name, b_id2name
elif self.cfg.model_net in ['StarGAN', 'STGAN', 'AttGAN']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
......@@ -611,7 +567,7 @@ class data_reader(object):
reader_test = test_reader.make_reader(return_name=True)
batch_num = train_reader.len()
reader = train_reader.make_reader()
return reader, reader_test, batch_num
return reader, reader_test, batch_num, None
elif self.cfg.model_net in ['Pix2pix']:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
......@@ -623,9 +579,9 @@ class data_reader(object):
list_filename=train_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last,
mode="TRAIN")
reader_test = None
id2name = None
if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt")
if self.cfg.test_list is not None:
......@@ -635,7 +591,6 @@ class data_reader(object):
list_filename=test_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last,
mode="TEST")
reader_test = test_reader.make_reader(
self.cfg, return_name=True)
......@@ -652,7 +607,6 @@ class data_reader(object):
list_filename=train_list,
shuffle=self.cfg.shuffle,
batch_size=self.cfg.batch_size,
drop_last=self.cfg.drop_last,
mode="TRAIN")
reader_test = None
if self.cfg.run_test:
......@@ -664,13 +618,13 @@ class data_reader(object):
list_filename=test_list,
shuffle=False,
batch_size=1,
drop_last=self.cfg.drop_last,
mode="TEST")
reader_test = test_reader.make_reader(
self.cfg, return_name=True)
id2name = test_reader.id2name
batch_num = train_reader.len()
reader = train_reader.make_reader(self.cfg)
return reader, reader_test, batch_num
return reader, reader_test, batch_num, id2name
else:
dataset_dir = os.path.join(self.cfg.data_dir, self.cfg.dataset)
train_list = os.path.join(dataset_dir, 'train.txt')
......@@ -679,14 +633,15 @@ class data_reader(object):
train_reader = reader_creator(
image_dir=dataset_dir, list_filename=train_list)
reader_test = None
id2name = None
if self.cfg.run_test:
test_list = os.path.join(dataset_dir, "test.txt")
test_reader = reader_creator(
image_dir=dataset_dir,
list_filename=test_list,
batch_size=self.cfg.n_samples,
drop_last=self.cfg.drop_last)
batch_size=self.cfg.n_samples)
reader_test = test_reader.get_test_reader(
self.cfg, shuffle=False, return_name=True)
id2name = test_reader.id2name
batch_num = train_reader.len()
return train_reader, reader_test, batch_num
return train_reader, reader_test, batch_num, id2name
......@@ -26,7 +26,7 @@ import numpy as np
import imageio
import glob
from util.config import add_arguments, print_arguments
from data_reader import celeba_reader_creator, triplex_reader_creator
from data_reader import celeba_reader_creator, reader_creator, triplex_reader_creato
from util.utility import check_attribute_conflict, check_gpu, save_batch_image
from util import utility
import copy
......@@ -78,9 +78,16 @@ def infer(args):
name='label_org_', shape=[args.c_dim], dtype='float32')
label_trg_ = fluid.layers.data(
name='label_trg_', shape=[args.c_dim], dtype='float32')
image_name = fluid.layers.data(
name='image_name', shape=[args.n_samples], dtype='int32')
model_name = 'net_G'
if args.model_net == 'CycleGAN':
py_reader = fluid.io.PyReader(
feed_list=[input, image_name],
capacity=4, ## batch_size * 4
iterable=True,
use_double_buffer=True)
from network.CycleGAN_network import CycleGAN_model
model = CycleGAN_model()
if args.input_style == "A":
......@@ -90,15 +97,35 @@ def infer(args):
else:
raise "Input with style [%s] is not supported." % args.input_style
elif args.model_net == 'Pix2pix':
py_reader = fluid.io.PyReader(
feed_list=[input, image_name],
capacity=4, ## batch_size * 4
iterable=True,
use_double_buffer=True)
from network.Pix2pix_network import Pix2pix_model
model = Pix2pix_model()
fake = model.network_G(input, "generator", cfg=args)
elif args.model_net == 'StarGAN':
py_reader = fluid.io.PyReader(
feed_list=[input, label_org_, label_trg_, image_name],
capacity=32,
iterable=True,
use_double_buffer=True)
from network.StarGAN_network import StarGAN_model
model = StarGAN_model()
fake = model.network_G(input, label_trg_, name="g_main", cfg=args)
elif args.model_net == 'STGAN':
from network.STGAN_network import STGAN_model
py_reader = fluid.io.PyReader(
feed_list=[input, label_org_, label_trg_, image_name],
capacity=32,
iterable=True,
use_double_buffer=True)
model = STGAN_model()
fake, _ = model.network_G(
input,
......@@ -109,6 +136,13 @@ def infer(args):
is_test=True)
elif args.model_net == 'AttGAN':
from network.AttGAN_network import AttGAN_model
py_reader = fluid.io.PyReader(
feed_list=[input, label_org_, label_trg_, image_name],
capacity=32,
iterable=True,
use_double_buffer=True)
model = AttGAN_model()
fake, _ = model.network_G(
input,
......@@ -124,14 +158,14 @@ def infer(args):
name='conditions', shape=[1], dtype='float32')
from network.CGAN_network import CGAN_model
model = CGAN_model()
model = CGAN_model(args.n_samples)
fake = model.network_G(noise, conditions, name="G")
elif args.model_net == 'DCGAN':
noise = fluid.layers.data(
name='noise', shape=[args.noise_size], dtype='float32')
from network.DCGAN_network import DCGAN_model
model = DCGAN_model()
model = DCGAN_model(args.n_samples)
fake = model.network_G(noise, name="G")
elif args.model_net == 'SPADE':
from network.SPADE_network import SPADE_model
......@@ -144,6 +178,13 @@ def infer(args):
raise NotImplementedError("model_net {} is not support".format(
args.model_net))
def _compute_start_end(image_name):
image_name_start = np.array(image_name)[0].astype('int32')
image_name_end = image_name_start + args.n_samples - 1
image_name_save = str(np.array(image_name)[0].astype('int32')) + '.jpg'
print("read {}.jpg ~ {}.jpg".format(image_name_start, image_name_end))
return image_name_save
# prepare environment
place = fluid.CPUPlace()
if args.use_gpu:
......@@ -167,36 +208,34 @@ def infer(args):
args=args,
mode="VAL")
reader_test = test_reader.make_reader(return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
print("read {}".format(name))
label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
tensor_label_trg = fluid.LoDTensor()
tensor_label_org_ = fluid.LoDTensor()
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
py_reader.decorate_batch_generator(
reader_test,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
for data in py_reader():
real_img, label_org, label_trg, image_name = data[0]['input'], data[
0]['label_org_'], data[0]['label_trg_'], data[0]['image_name']
image_name_save = _compute_start_end(image_name)
real_img_temp = save_batch_image(np.array(real_img))
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
for j in range(len(label_org)):
label_trg_tmp = copy.deepcopy(np.array(label_trg))
for j in range(len(label_trg_tmp)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
label_trg_ = list(
label_org_tmp = list(
map(lambda x: ((x * 2) - 1) * 0.5, np.array(label_org)))
label_trg_tmp = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
if args.model_net == 'AttGAN':
for k in range(len(label_org)):
label_trg_[k][i] = label_trg_[k][i] * 2.0
tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place)
for k in range(len(label_trg_tmp)):
label_trg_tmp[k][i] = label_trg_tmp[k][i] * 2.0
tensor_label_org_ = fluid.LoDTensor()
tensor_label_trg_ = fluid.LoDTensor()
tensor_label_org_.set(label_org_tmp, place)
tensor_label_trg_.set(label_trg_tmp, place)
out = exe.run(feed={
"input": tensor_img,
"input": real_img,
"label_org_": tensor_label_org_,
"label_trg_": tensor_label_trg_
},
......@@ -204,10 +243,11 @@ def infer(args):
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
if len(np.array(label_org)) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(args.output, "fake_img_" + image_name_save), (
(images_concat + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'StarGAN':
test_reader = celeba_reader_creator(
image_dir=args.dataset_dir,
......@@ -215,61 +255,66 @@ def infer(args):
args=args,
mode="VAL")
reader_test = test_reader.make_reader(return_name=True)
for data in zip(reader_test()):
real_img, label_org, name = data[0]
print("read {}".format(name))
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
py_reader.decorate_batch_generator(
reader_test,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
for data in py_reader():
real_img, label_org, label_trg, image_name = data[0]['input'], data[
0]['label_org_'], data[0]['label_trg_'], data[0]['image_name']
image_name_save = _compute_start_end(image_name)
real_img_temp = save_batch_image(np.array(real_img))
images = [real_img_temp]
for i in range(args.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
for j in range(len(label_org)):
label_trg_tmp = copy.deepcopy(np.array(label_org))
for j in range(len(np.array(label_org))):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
tensor_label_trg_ = fluid.LoDTensor()
tensor_label_trg_.set(label_trg_tmp, place)
out = exe.run(
feed={"input": tensor_img,
"label_trg_": tensor_label_trg},
feed={"input": real_img,
"label_trg_": tensor_label_trg_},
fetch_list=[fake.name])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
if len(np.array(label_org)) > 1:
images_concat = np.concatenate(images_concat, 1)
imageio.imwrite(args.output + "/fake_img_" + name[0], (
(images_concat + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(args.output, "fake_img_" + image_name_save), (
(images_concat + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'Pix2pix' or args.model_net == 'CycleGAN':
for file in glob.glob(args.dataset_dir):
print("read {}".format(file))
image_name = os.path.basename(file)
image = Image.open(file).convert('RGB')
image = image.resize((args.image_size, args.image_size), Image.BICUBIC)
image = np.array(image).transpose([2, 0, 1]).astype('float32')
image = image / 255.0
image = (image - 0.5) / 0.5
data = image[np.newaxis, :]
tensor = fluid.LoDTensor()
tensor.set(data, place)
fake_temp = exe.run(fetch_list=[fake.name], feed={"input": tensor})
test_reader = reader_creator(
image_dir=args.dataset_dir,
list_filename=args.test_list,
shuffle=False,
batch_size=args.n_samples,
mode="VAL")
reader_test = test_reader.make_reader(args, return_name=True)
py_reader.decorate_batch_generator(
reader_test,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
id2name = test_reader.id2name
for data in py_reader():
real_img, image_name = data[0]['input'], data[0]['image_name']
image_name = id2name[np.array(image_name).astype('int32')[0]]
print("read: ", image_name)
fake_temp = exe.run(fetch_list=[fake.name],
feed={"input": real_img})
fake_temp = np.squeeze(fake_temp[0]).transpose([1, 2, 0])
input_temp = np.squeeze(data).transpose([1, 2, 0])
input_temp = np.squeeze(np.array(real_img)[0]).transpose([1, 2, 0])
imageio.imwrite(args.output + "/fake_" + image_name, (
(fake_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
os.path.join(args.output, "fake_" + image_name), (
(fake_temp + 1) * 127.5).astype(np.uint8))
elif args.model_net == 'SPADE':
test_reader = triplex_reader_creator(
image_dir=args.dataset_dir,
list_filename=args.test_list,
shuffle=False,
batch_size=1,
drop_last=False,
mode="TEST")
reader_test = test_reader.make_reader(
args, return_name=True)
......@@ -327,7 +372,7 @@ def infer(args):
fig = utility.plot(fake_image)
plt.savefig(
os.path.join(args.output, '/fake_dcgan.png'), bbox_inches='tight')
os.path.join(args.output, 'fake_dcgan.png'), bbox_inches='tight')
plt.close(fig)
else:
raise NotImplementedError("model_net {} is not support".format(
......
......@@ -172,7 +172,7 @@ def conv2d(input,
if padding_type == "SAME":
top_padding, bottom_padding = cal_padding(input.shape[2], stride,
filter_size)
left_padding, right_padding = cal_padding(input.shape[2], stride,
left_padding, right_padding = cal_padding(input.shape[3], stride,
filter_size)
height_padding = bottom_padding
width_padding = right_padding
......@@ -260,7 +260,7 @@ def deconv2d(input,
if padding_type == "SAME":
top_padding, bottom_padding = cal_padding(input.shape[2], stride,
filter_size)
left_padding, right_padding = cal_padding(input.shape[2], stride,
left_padding, right_padding = cal_padding(input.shape[3], stride,
filter_size)
height_padding = bottom_padding
width_padding = right_padding
......@@ -288,7 +288,7 @@ def deconv2d(input,
param_attr=param_attr,
bias_attr=bias_attr)
if outpadding != 0 and padding_type == None:
if np.mean(outpadding) != 0 and padding_type == None:
conv = fluid.layers.pad2d(
conv, paddings=outpadding, mode='constant', pad_value=0.0)
......
......@@ -38,24 +38,25 @@ def train(cfg):
reader = data_reader(cfg)
if cfg.model_net in ['CycleGAN']:
a_reader, b_reader, a_reader_test, b_reader_test, batch_num = reader.make_data(
a_reader, b_reader, a_reader_test, b_reader_test, batch_num, a_id2name, b_id2name = reader.make_data(
)
else:
if cfg.dataset in ['mnist']:
train_reader = reader.make_data()
else:
train_reader, test_reader, batch_num = reader.make_data()
train_reader, test_reader, batch_num, id2name = reader.make_data()
if cfg.model_net in ['CGAN', 'DCGAN']:
if cfg.dataset != 'mnist':
raise NotImplementedError("CGAN/DCGAN only support MNIST now!")
model = trainer.__dict__[cfg.model_net](cfg, train_reader)
elif cfg.model_net in ['CycleGAN']:
model = trainer.__dict__[cfg.model_net](
cfg, a_reader, b_reader, a_reader_test, b_reader_test, batch_num)
model = trainer.__dict__[cfg.model_net](cfg, a_reader, b_reader,
a_reader_test, b_reader_test,
batch_num, a_id2name, b_id2name)
else:
model = trainer.__dict__[cfg.model_net](cfg, train_reader, test_reader,
batch_num)
batch_num, id2name)
model.build_model()
......
......@@ -27,9 +27,8 @@ import six
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
#import imageio
from PIL import Image
import copy
from PIL import Image
img_dim = 28
......@@ -68,6 +67,66 @@ def init_checkpoints(cfg, exe, trainer, name):
sys.stdout.flush()
### the initialize checkpoint is one file named checkpoint.pdparams
def init_from_checkpoint(args, exe, trainer, name):
    """Restore persistable variables from a single-file checkpoint.

    The checkpoint is read from the file ``checkpoint.pdparams`` inside
    ``os.path.join(args.init_model, name)`` and loaded into
    ``trainer.program``.

    Args:
        args: object exposing ``init_model``, the checkpoint root directory.
        exe: executor passed to ``fluid.io.load_persistables``.
        trainer: object exposing ``program``, the program to load into.
        name: sub-directory of ``args.init_model`` that holds the checkpoint.

    Returns:
        True once loading has finished.

    Raises:
        Warning: if ``args.init_model`` does not exist.
    """
    if not os.path.exists(args.init_model):
        # A missing path is signalled by raising; nothing can follow the
        # raise, so this function never returns False.
        raise Warning("the checkpoint path does not exist.")

    fluid.io.load_persistables(
        executor=exe,
        dirname=os.path.join(args.init_model, name),
        main_program=trainer.program,
        filename="checkpoint.pdparams")
    print("finish initing model from checkpoint from %s" % (args.init_model))
    return True
### save the parameters of generator to one file
def save_param(args, exe, program, dirname, var_name="generator"):
    """Save the parameters whose names start with ``var_name`` to one file.

    Variables are written to ``params.pdparams`` inside
    ``<args.output>/infer_vars/<dirname>``. The ``infer_vars`` directory is
    created when it does not exist yet.

    Args:
        args: object exposing ``output``, the root output directory.
        exe: executor passed to ``fluid.io.save_vars``.
        program: program whose variables are filtered and saved.
        dirname: sub-directory of ``infer_vars`` to save into.
        var_name: name prefix a parameter must have to be saved.

    Returns:
        True after the save call and the log line have completed.
    """
    infer_root = os.path.join(args.output, 'infer_vars')
    if not os.path.exists(infer_root):
        os.makedirs(infer_root)
    target_dir = os.path.join(infer_root, dirname)

    def _is_selected_param(var):
        # Keep only parameters whose name carries the requested prefix; the
        # decision for every inspected variable is printed.
        selected = (fluid.io.is_parameter(var) and
                    var.name.startswith(var_name))
        print(var.name, selected)
        return selected

    fluid.io.save_vars(
        exe,
        target_dir,
        main_program=program,
        predicate=_is_selected_param,
        filename="params.pdparams")
    print("save parameters at %s" % (target_dir))
    return True
### save the checkpoint to one file
def save_checkpoint(epoch, args, exe, trainer, dirname):
    """Save all persistable variables of ``trainer.program`` to one file.

    The file ``checkpoint.pdparams`` is written inside
    ``<args.output>/checkpoints/<epoch>/<dirname>``. The per-epoch directory
    is created when it does not exist yet.

    Args:
        epoch: epoch identifier; converted with ``str`` to form the path.
        args: object exposing ``output``, the root output directory.
        exe: executor passed to ``fluid.io.save_persistables``.
        trainer: object exposing ``program``, the program to save.
        dirname: sub-directory of the epoch directory to save into.

    Returns:
        True after the save call and the log line have completed.
    """
    epoch_dir = os.path.join(args.output, 'checkpoints', str(epoch))
    if not os.path.exists(epoch_dir):
        os.makedirs(epoch_dir)
    target_dir = os.path.join(epoch_dir, dirname)

    fluid.io.save_persistables(
        exe,
        target_dir,
        main_program=trainer.program,
        filename="checkpoint.pdparams")
    print("save checkpoint at %s" % (target_dir))
    return True
def save_test_image(epoch,
cfg,
exe,
......@@ -75,42 +134,42 @@ def save_test_image(epoch,
test_program,
g_trainer,
A_test_reader,
B_test_reader=None):
B_test_reader=None,
A_id2name=None,
B_id2name=None):
out_path = os.path.join(cfg.output, 'test')
if not os.path.exists(out_path):
os.makedirs(out_path)
if cfg.model_net == "Pix2pix":
for data in zip(A_test_reader()):
data_A, data_B, name = data[0]
name = name[0]
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_A.set(data_A, place)
tensor_B.set(data_B, place)
fake_B_temp = exe.run(
test_program,
fetch_list=[g_trainer.fake_B],
feed={"input_A": tensor_A,
"input_B": tensor_B})
for data in A_test_reader():
A_data, B_data, image_name = data[0]['input_A'], data[0][
'input_B'], data[0]['image_name']
fake_B_temp = exe.run(test_program,
fetch_list=[g_trainer.fake_B],
feed={"input_A": A_data,
"input_B": B_data})
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
#imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + name, (
# (fake_B_temp + 1) * 127.5).astype(np.uint8))
#imageio.imwrite(out_path + "/inputA_" + str(epoch) + "_" + name, (
# (input_A_temp + 1) * 127.5).astype(np.uint8))
#imageio.imwrite(out_path + "/inputB_" + str(epoch) + "_" + name, (
# (input_B_temp + 1) * 127.5).astype(np.uint8))
input_A_temp = np.squeeze(np.array(A_data)[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(np.array(A_data)[0]).transpose([1, 2, 0])
fakeB_name = "fakeB_" + str(epoch) + "_" + A_id2name[np.array(
image_name).astype('int32')[0]]
inputA_name = "inputA_" + str(epoch) + "_" + A_id2name[np.array(
image_name).astype('int32')[0]]
inputB_name = "inputB_" + str(epoch) + "_" + A_id2name[np.array(
image_name).astype('int32')[0]]
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
res_fakeB.save(os.path.join(out_path+"/fakeB_" + str(epoch) + "_", name))
res_fakeB.save(os.path.join(out_path, fakeB_name))
res_inputA = Image.fromarray(((input_A_temp + 1) * 127.5).astype(
np.uint8))
res_inputA.save(os.path.join(out_path+"/inputA_" + str(epoch) + "_", name))
res_inputA.save(os.path.join(out_path, inputA_name))
res_inputB = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_inputB.save(os.path.join(out_path+"/inputB_" + str(epoch) + "_", name))
res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "SPADE":
for data in zip(A_test_reader()):
data_A, data_B, data_C, name = data[0]
......@@ -128,45 +187,35 @@ def save_test_image(epoch,
"input_img": tensor_B,
"input_ins": tensor_C})
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
#input_A_temp = np.squeeze(data_A[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
#imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + name, (
# (fake_B_temp + 1) * 127.5).astype(np.uint8))
#imageio.imwrite(out_path + "/real_" + str(epoch) + "_" + name, (
# (input_B_temp + 1) * 127.5).astype(np.uint8))
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
#res_fakeB.save(os.path.join(out_path+"/fakeB_"+str(epoch)+"_", name))
res_fakeB.save(out_path+"/fakeB_"+str(epoch)+"_"+name)
res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
#res_real.save(os.path.join(out_path+"/real_"+str(epoch)+"_", name))
res_real.save(out_path+"/real_"+str(epoch)+"_"+name)
elif cfg.model_net == "StarGAN":
for data in zip(A_test_reader()):
real_img, label_org, name = data[0]
for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][
'image_real'], data[0]['label_org'], data[0]['label_trg'], data[
0]['image_name']
attr_names = cfg.selected_attrs.split(',')
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
real_img_temp = save_batch_image(np.array(real_img))
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_org)
for j in range(len(label_org)):
label_trg_tmp = copy.deepcopy(np.array(label_org))
for j in range(len(np.array(label_org))):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg = check_attribute_conflict(
np_label_trg = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
tensor_label_trg = fluid.LoDTensor()
tensor_label_trg.set(label_trg, place)
label_trg.set(np_label_trg, place)
fake_temp, rec_temp = exe.run(
test_program,
feed={
"image_real": tensor_img,
"label_org": tensor_label_org,
"label_trg": tensor_label_trg
"image_real": real_img,
"label_org": label_org,
"label_trg": label_trg
},
fetch_list=[g_trainer.fake_img, g_trainer.rec_img])
fake_temp = save_batch_image(fake_temp)
......@@ -174,102 +223,120 @@ def save_test_image(epoch,
images.append(fake_temp)
images.append(rec_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
if len(np.array(label_org)) > 1:
images_concat = np.concatenate(images_concat, 1)
#imageio.imwrite(out_path + "/fake_img" + str(epoch) + "_" + name[0],
# ((images_concat + 1) * 127.5).astype(np.uint8))
image_name_save = "fake_img" + str(epoch) + "_" + str(
np.array(image_name)[0].astype('int32')) + '.jpg'
res = Image.fromarray(((images_concat + 1) * 127.5).astype(
np.uint8))
res.save(os.path.join(out_path+"/fake_img" + str(epoch) + "_", name[0]))
res.save(os.path.join(out_path, image_name_save))
elif cfg.model_net == 'AttGAN' or cfg.model_net == 'STGAN':
for data in zip(A_test_reader()):
real_img, label_org, name = data[0]
for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][
'image_real'], data[0]['label_org'], data[0]['label_trg'], data[
0]['image_name']
attr_names = cfg.selected_attrs.split(',')
label_trg = copy.deepcopy(label_org)
tensor_img = fluid.LoDTensor()
tensor_label_org = fluid.LoDTensor()
tensor_label_trg = fluid.LoDTensor()
tensor_label_org_ = fluid.LoDTensor()
tensor_label_trg_ = fluid.LoDTensor()
tensor_img.set(real_img, place)
tensor_label_org.set(label_org, place)
real_img_temp = save_batch_image(real_img)
real_img_temp = save_batch_image(np.array(real_img))
images = [real_img_temp]
for i in range(cfg.c_dim):
label_trg_tmp = copy.deepcopy(label_trg)
label_trg_tmp = copy.deepcopy(np.array(label_trg))
for j in range(len(label_org)):
for j in range(len(label_trg_tmp)):
label_trg_tmp[j][i] = 1.0 - label_trg_tmp[j][i]
label_trg_tmp = check_attribute_conflict(
label_trg_tmp, attr_names[i], attr_names)
label_org_ = list(map(lambda x: ((x * 2) - 1) * 0.5, label_org))
label_trg_ = list(
label_org_tmp = list(
map(lambda x: ((x * 2) - 1) * 0.5, np.array(label_org)))
label_trg_tmp = list(
map(lambda x: ((x * 2) - 1) * 0.5, label_trg_tmp))
if cfg.model_net == 'AttGAN':
for k in range(len(label_org)):
label_trg_[k][i] = label_trg_[k][i] * 2.0
tensor_label_org_.set(label_org_, place)
tensor_label_trg.set(label_trg, place)
tensor_label_trg_.set(label_trg_, place)
for k in range(len(label_trg_tmp)):
label_trg_tmp[k][i] = label_trg_tmp[k][i] * 2.0
tensor_label_org_ = fluid.LoDTensor()
tensor_label_org_.set(label_org_tmp, place)
tensor_label_trg_ = fluid.LoDTensor()
tensor_label_trg_.set(label_trg_tmp, place)
out = exe.run(test_program,
feed={
"image_real": tensor_img,
"label_org": tensor_label_org,
"image_real": real_img,
"label_org": label_org,
"label_org_": tensor_label_org_,
"label_trg": tensor_label_trg,
"label_trg": label_trg,
"label_trg_": tensor_label_trg_
},
fetch_list=[g_trainer.fake_img])
fake_temp = save_batch_image(out[0])
images.append(fake_temp)
images_concat = np.concatenate(images, 1)
if len(label_org) > 1:
if len(label_trg_tmp) > 1:
images_concat = np.concatenate(images_concat, 1)
#imageio.imwrite(out_path + "/fake_img" + str(epoch) + '_' + name[0],
# ((images_concat + 1) * 127.5).astype(np.uint8))
image_name_save = 'fake_img_' + str(epoch) + '_' + str(
np.array(image_name)[0].astype('int32')) + '.jpg'
res = Image.fromarray(((images_concat + 1) * 127.5).astype(
np.uint8))
res.save(os.path.join(out_path+"/fake_img" + str(epoch) + '_', name[0]))
res.save(os.path.join(out_path, image_name_save))
else:
for data_A, data_B in zip(A_test_reader(), B_test_reader()):
A_data, A_name = data_A
B_data, B_name = data_B
tensor_A = fluid.LoDTensor()
tensor_B = fluid.LoDTensor()
tensor_A.set(A_data, place)
tensor_B.set(B_data, place)
A_data, A_name = data_A[0]['input_A'], data_A[0]['A_image_name']
B_data, B_name = data_B[0]['input_B'], data_B[0]['B_image_name']
fake_A_temp, fake_B_temp, cyc_A_temp, cyc_B_temp = exe.run(
test_program,
fetch_list=[
g_trainer.fake_A, g_trainer.fake_B, g_trainer.cyc_A,
g_trainer.cyc_B
],
feed={"input_A": tensor_A,
"input_B": tensor_B})
feed={"input_A": A_data,
"input_B": B_data})
fake_A_temp = np.squeeze(fake_A_temp[0]).transpose([1, 2, 0])
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
cyc_A_temp = np.squeeze(cyc_A_temp[0]).transpose([1, 2, 0])
cyc_B_temp = np.squeeze(cyc_B_temp[0]).transpose([1, 2, 0])
input_A_temp = np.squeeze(data_A[0][0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0][0]).transpose([1, 2, 0])
imageio.imwrite(out_path + "/fakeB_" + str(epoch) + "_" + A_name[0],
((fake_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/fakeA_" + str(epoch) + "_" + B_name[0],
((fake_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycA_" + str(epoch) + "_" + A_name[0],
((cyc_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(out_path + "/cycB_" + str(epoch) + "_" + B_name[0],
((cyc_B_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
out_path + "/inputA_" + str(epoch) + "_" + A_name[0], (
(input_A_temp + 1) * 127.5).astype(np.uint8))
imageio.imwrite(
out_path + "/inputB_" + str(epoch) + "_" + B_name[0], (
(input_B_temp + 1) * 127.5).astype(np.uint8))
input_A_temp = np.squeeze(np.array(A_data)).transpose([1, 2, 0])
input_B_temp = np.squeeze(np.array(B_data)).transpose([1, 2, 0])
fakeA_name = "fakeA_" + str(epoch) + "_" + A_id2name[np.array(
A_name).astype('int32')[0]]
fakeB_name = "fakeB_" + str(epoch) + "_" + B_id2name[np.array(
B_name).astype('int32')[0]]
inputA_name = "inputA_" + str(epoch) + "_" + A_id2name[np.array(
A_name).astype('int32')[0]]
inputB_name = "inputB_" + str(epoch) + "_" + B_id2name[np.array(
B_name).astype('int32')[0]]
cycA_name = "cycA_" + str(epoch) + "_" + A_id2name[np.array(
A_name).astype('int32')[0]]
cycB_name = "cycB_" + str(epoch) + "_" + B_id2name[np.array(
B_name).astype('int32')[0]]
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
res_fakeB.save(os.path.join(out_path, fakeB_name))
res_fakeA = Image.fromarray(((fake_A_temp + 1) * 127.5).astype(
np.uint8))
res_fakeA.save(os.path.join(out_path, fakeA_name))
res_cycA = Image.fromarray(((cyc_A_temp + 1) * 127.5).astype(
np.uint8))
res_cycA.save(os.path.join(out_path, cycA_name))
res_cycB = Image.fromarray(((cyc_B_temp + 1) * 127.5).astype(
np.uint8))
res_cycB.save(os.path.join(out_path, cycB_name))
res_inputA = Image.fromarray(((input_A_temp + 1) * 127.5).astype(
np.uint8))
res_inputA.save(os.path.join(out_path, inputA_name))
res_inputB = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_inputB.save(os.path.join(out_path, inputB_name))
class ImagePool(object):
......@@ -321,6 +388,7 @@ def check_attribute_conflict(label_batch, attr, attrs):
def save_batch_image(img):
#if img.shape[0] == 1:
if len(img) == 1:
res_img = np.squeeze(img).transpose([1, 2, 0])
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册