未验证 提交 88d125f2 编写于 作者: L lvmengsi 提交者: GitHub

fix spade (#3387)

上级 32ae4f2b
......@@ -237,6 +237,9 @@ class triplex_reader_creator(reader_creator):
batch_size=batch_size,
mode=mode)
self.name2id = {}
self.id2name = {}
def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename)
print("files length:", len(self.lines))
......@@ -248,11 +251,13 @@ class triplex_reader_creator(reader_creator):
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
for i, line in enumerate(self.lines):
files = line.strip('\n\r\t ').split('\t')
if len(files) != 3:
print("files is not equal to 3!")
sys.exit(-1)
self.name2id[os.path.basename(files[0])] = i
self.id2name[i] = os.path.basename(files[0])
#label image instance
img1 = Image.open(os.path.join(self.image_dir, files[0]))
img2 = Image.open(os.path.join(self.image_dir, files[
......@@ -327,7 +332,7 @@ class triplex_reader_creator(reader_creator):
if not args.no_instance:
batch_out_3.append(img3)
if return_name:
batch_out_name.append(os.path.basename(files[0]))
batch_out_name.append(i)
if len(batch_out_1) == self.batch_size:
if return_name:
if not args.no_instance:
......
......@@ -274,11 +274,13 @@ class SPADE(object):
cfg=None,
train_reader=None,
test_reader=None,
batch_num=1):
batch_num=1,
id2name=None):
self.cfg = cfg
self.train_reader = train_reader
self.test_reader = test_reader
self.batch_num = batch_num
self.id2name = id2name
def build_model(self):
data_shape = [-1, 3, self.cfg.crop_height, self.cfg.crop_width]
......@@ -324,7 +326,7 @@ class SPADE(object):
### memory optim
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.sync_batch_norm = True
build_strategy.sync_batch_norm = False
gen_trainer_program = fluid.CompiledProgram(
gen_trainer.program).with_data_parallel(
......@@ -340,8 +342,8 @@ class SPADE(object):
for epoch_id in range(self.cfg.epoch):
batch_id = 0
for tensor in py_reader():
data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][
'input_B'], tensor[0]['input_C']
data_A, data_B, data_C = tensor[0]['input_label'], tensor[0][
'input_img'], tensor[0]['input_ins']
s_time = time.time()
# optimize the generator network
g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run(
......@@ -390,7 +392,7 @@ class SPADE(object):
shape=[self.cfg.batch_size],
dtype="int32")
test_py_reader = fluid.io.PyReader(
feed_list=[input_A, input_B, image_name],
feed_list=[input_A, input_B, input_C, image_name],
capacity=4, ## batch_size * 4
iterable=True,
use_double_buffer=True)
......@@ -398,9 +400,15 @@ class SPADE(object):
self.test_reader,
places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places())
utility.save_test_image(epoch_id, self.cfg, exe, place,
test_program, gen_trainer,
test_py_reader)
utility.save_test_image(
epoch_id,
self.cfg,
exe,
place,
test_program,
gen_trainer,
test_py_reader,
A_id2name=self.id2name)
if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
......
......@@ -19,6 +19,7 @@ from .Pix2pix import Pix2pix
from .STGAN import STGAN
from .StarGAN import StarGAN
from .AttGAN import AttGAN
from .SPADE import SPADE
import importlib
......
......@@ -172,8 +172,8 @@ def save_test_image(epoch,
res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "SPADE":
for data in A_test_reader():
data_A, data_B, data_C, name = data[0]['input_A'], data[0][
'input_B'], data[0]['input_C'], data[0]['image_name']
data_A, data_B, data_C, name = data[0]['input_label'], data[0][
'input_img'], data[0]['input_ins'], data[0]['image_name']
fake_B_temp = exe.run(test_program,
fetch_list=[g_trainer.fake_B],
feed={
......@@ -183,13 +183,14 @@ def save_test_image(epoch,
})
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
image_name = A_id2name[np.array(name).astype('int32')[0]]
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8))
res_fakeB.save(out_path + "/fakeB_" + str(epoch) + "_" + name)
res_fakeB.save(out_path + "/fakeB_" + str(epoch) + "_" + image_name)
res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8))
res_real.save(out_path + "/real_" + str(epoch) + "_" + name)
res_real.save(out_path + "/real_" + str(epoch) + "_" + image_name)
elif cfg.model_net == "StarGAN":
for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册