未验证 提交 88d125f2 编写于 作者: L lvmengsi 提交者: GitHub

fix spade (#3387)

上级 32ae4f2b
...@@ -237,6 +237,9 @@ class triplex_reader_creator(reader_creator): ...@@ -237,6 +237,9 @@ class triplex_reader_creator(reader_creator):
batch_size=batch_size, batch_size=batch_size,
mode=mode) mode=mode)
self.name2id = {}
self.id2name = {}
def make_reader(self, args, return_name=False): def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename) print(self.image_dir, self.list_filename)
print("files length:", len(self.lines)) print("files length:", len(self.lines))
...@@ -248,11 +251,13 @@ class triplex_reader_creator(reader_creator): ...@@ -248,11 +251,13 @@ class triplex_reader_creator(reader_creator):
batch_out_name = [] batch_out_name = []
if self.shuffle: if self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
for line in self.lines: for i, line in enumerate(self.lines):
files = line.strip('\n\r\t ').split('\t') files = line.strip('\n\r\t ').split('\t')
if len(files) != 3: if len(files) != 3:
print("files is not equal to 3!") print("files is not equal to 3!")
sys.exit(-1) sys.exit(-1)
self.name2id[os.path.basename(files[0])] = i
self.id2name[i] = os.path.basename(files[0])
#label image instance #label image instance
img1 = Image.open(os.path.join(self.image_dir, files[0])) img1 = Image.open(os.path.join(self.image_dir, files[0]))
img2 = Image.open(os.path.join(self.image_dir, files[ img2 = Image.open(os.path.join(self.image_dir, files[
...@@ -327,7 +332,7 @@ class triplex_reader_creator(reader_creator): ...@@ -327,7 +332,7 @@ class triplex_reader_creator(reader_creator):
if not args.no_instance: if not args.no_instance:
batch_out_3.append(img3) batch_out_3.append(img3)
if return_name: if return_name:
batch_out_name.append(os.path.basename(files[0])) batch_out_name.append(i)
if len(batch_out_1) == self.batch_size: if len(batch_out_1) == self.batch_size:
if return_name: if return_name:
if not args.no_instance: if not args.no_instance:
......
...@@ -274,11 +274,13 @@ class SPADE(object): ...@@ -274,11 +274,13 @@ class SPADE(object):
cfg=None, cfg=None,
train_reader=None, train_reader=None,
test_reader=None, test_reader=None,
batch_num=1): batch_num=1,
id2name=None):
self.cfg = cfg self.cfg = cfg
self.train_reader = train_reader self.train_reader = train_reader
self.test_reader = test_reader self.test_reader = test_reader
self.batch_num = batch_num self.batch_num = batch_num
self.id2name = id2name
def build_model(self): def build_model(self):
data_shape = [-1, 3, self.cfg.crop_height, self.cfg.crop_width] data_shape = [-1, 3, self.cfg.crop_height, self.cfg.crop_width]
...@@ -324,7 +326,7 @@ class SPADE(object): ...@@ -324,7 +326,7 @@ class SPADE(object):
### memory optim ### memory optim
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.sync_batch_norm = True build_strategy.sync_batch_norm = False
gen_trainer_program = fluid.CompiledProgram( gen_trainer_program = fluid.CompiledProgram(
gen_trainer.program).with_data_parallel( gen_trainer.program).with_data_parallel(
...@@ -340,8 +342,8 @@ class SPADE(object): ...@@ -340,8 +342,8 @@ class SPADE(object):
for epoch_id in range(self.cfg.epoch): for epoch_id in range(self.cfg.epoch):
batch_id = 0 batch_id = 0
for tensor in py_reader(): for tensor in py_reader():
data_A, data_B, data_C = tensor[0]['input_A'], tensor[0][ data_A, data_B, data_C = tensor[0]['input_label'], tensor[0][
'input_B'], tensor[0]['input_C'] 'input_img'], tensor[0]['input_ins']
s_time = time.time() s_time = time.time()
# optimize the generator network # optimize the generator network
g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run( g_loss_gan, g_loss_vgg, g_loss_feat, fake_B_tmp = exe.run(
...@@ -390,7 +392,7 @@ class SPADE(object): ...@@ -390,7 +392,7 @@ class SPADE(object):
shape=[self.cfg.batch_size], shape=[self.cfg.batch_size],
dtype="int32") dtype="int32")
test_py_reader = fluid.io.PyReader( test_py_reader = fluid.io.PyReader(
feed_list=[input_A, input_B, image_name], feed_list=[input_A, input_B, input_C, image_name],
capacity=4, ## batch_size * 4 capacity=4, ## batch_size * 4
iterable=True, iterable=True,
use_double_buffer=True) use_double_buffer=True)
...@@ -398,9 +400,15 @@ class SPADE(object): ...@@ -398,9 +400,15 @@ class SPADE(object):
self.test_reader, self.test_reader,
places=fluid.cuda_places() places=fluid.cuda_places()
if self.cfg.use_gpu else fluid.cpu_places()) if self.cfg.use_gpu else fluid.cpu_places())
utility.save_test_image(epoch_id, self.cfg, exe, place, utility.save_test_image(
test_program, gen_trainer, epoch_id,
test_py_reader) self.cfg,
exe,
place,
test_program,
gen_trainer,
test_py_reader,
A_id2name=self.id2name)
if self.cfg.save_checkpoints: if self.cfg.save_checkpoints:
utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer, utility.checkpoints(epoch_id, self.cfg, exe, gen_trainer,
......
...@@ -19,6 +19,7 @@ from .Pix2pix import Pix2pix ...@@ -19,6 +19,7 @@ from .Pix2pix import Pix2pix
from .STGAN import STGAN from .STGAN import STGAN
from .StarGAN import StarGAN from .StarGAN import StarGAN
from .AttGAN import AttGAN from .AttGAN import AttGAN
from .SPADE import SPADE
import importlib import importlib
......
...@@ -172,8 +172,8 @@ def save_test_image(epoch, ...@@ -172,8 +172,8 @@ def save_test_image(epoch,
res_inputB.save(os.path.join(out_path, inputB_name)) res_inputB.save(os.path.join(out_path, inputB_name))
elif cfg.model_net == "SPADE": elif cfg.model_net == "SPADE":
for data in A_test_reader(): for data in A_test_reader():
data_A, data_B, data_C, name = data[0]['input_A'], data[0][ data_A, data_B, data_C, name = data[0]['input_label'], data[0][
'input_B'], data[0]['input_C'], data[0]['image_name'] 'input_img'], data[0]['input_ins'], data[0]['image_name']
fake_B_temp = exe.run(test_program, fake_B_temp = exe.run(test_program,
fetch_list=[g_trainer.fake_B], fetch_list=[g_trainer.fake_B],
feed={ feed={
...@@ -183,13 +183,14 @@ def save_test_image(epoch, ...@@ -183,13 +183,14 @@ def save_test_image(epoch,
}) })
fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0]) fake_B_temp = np.squeeze(fake_B_temp[0]).transpose([1, 2, 0])
input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0]) input_B_temp = np.squeeze(data_B[0]).transpose([1, 2, 0])
image_name = A_id2name[np.array(name).astype('int32')[0]]
res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype( res_fakeB = Image.fromarray(((fake_B_temp + 1) * 127.5).astype(
np.uint8)) np.uint8))
res_fakeB.save(out_path + "/fakeB_" + str(epoch) + "_" + name) res_fakeB.save(out_path + "/fakeB_" + str(epoch) + "_" + image_name)
res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype( res_real = Image.fromarray(((input_B_temp + 1) * 127.5).astype(
np.uint8)) np.uint8))
res_real.save(out_path + "/real_" + str(epoch) + "_" + name) res_real.save(out_path + "/real_" + str(epoch) + "_" + image_name)
elif cfg.model_net == "StarGAN": elif cfg.model_net == "StarGAN":
for data in A_test_reader(): for data in A_test_reader():
real_img, label_org, label_trg, image_name = data[0][ real_img, label_org, label_trg, image_name = data[0][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册