未验证 提交 a5c2db94 编写于 作者: C ceci3 提交者: GitHub

Fix dygraph cyclegan (#4592)

* fix init

* update

* fix
上级 ccdbfe77
......@@ -99,16 +99,24 @@ class reader_creator(object):
def make_reader(self, args, return_name=False):
print(self.image_dir, self.list_filename)
self.with_label = False
def reader():
batch_out = []
batch_out_label = []
batch_out_name = []
if self.shuffle:
np.random.shuffle(self.lines)
for i, file in enumerate(self.lines):
file = file.strip('\n\r\t ')
for i, line in enumerate(self.lines):
line = line.strip('\n\r\t').split(' ')
if len(line) > 1:
self.with_label = True
batch_out_label.append(line[1])
file = line[0]
else:
file = line[0]
self.name2id[os.path.basename(file)] = i
self.id2name[i] = os.path.basename(file)
img = Image.open(os.path.join(self.image_dir, file)).convert(
......@@ -133,10 +141,18 @@ class reader_creator(object):
batch_out.append(img)
if len(batch_out) == self.batch_size:
if return_name:
yield batch_out, batch_out_name
if self.with_label:
yield [[batch_out, batch_out_label, batch_out_name]]
batch_out_label = []
else:
yield batch_out, batch_out_name
batch_out_name = []
else:
yield [batch_out]
if self.with_label:
yield [[batch_out, batch_out_label]]
batch_out_label = []
else:
yield [batch_out]
batch_out = []
return reader
......@@ -667,8 +683,9 @@ class data_reader(object):
image_dir=dataset_dir,
list_filename=test_list,
batch_size=self.cfg.n_samples)
reader_test = test_reader.get_test_reader(
reader_test = test_reader.make_reader(
self.cfg, shuffle=False, return_name=True)
id2name = test_reader.id2name
batch_num = train_reader.len()
return train_reader, reader_test, batch_num, id2name
reader = train_reader.make_reader(self.cfg)
return reader, reader_test, batch_num, id2name
......@@ -50,7 +50,7 @@ def infer():
out_path = args.output + "/single"
if not os.path.exists(out_path):
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
cycle_gan = Cycle_Gan(3)
save_dir = args.init_model
restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.set_dict(restore)
......
......@@ -50,7 +50,7 @@ def test():
out_path = args.output + "/eval" + "/" + str(epoch)
if not os.path.exists(out_path):
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
cycle_gan = Cycle_Gan(3)
save_dir = args.init_model + str(epoch)
restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.set_dict(restore)
......
......@@ -44,7 +44,7 @@ add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.")
lambda_A = 10.0
lambda_B = 10.0
lambda_identity = 0.5
tep_per_epoch = 2974
step_per_epoch = 2974
def optimizer_setting(parameters):
......@@ -90,7 +90,8 @@ def train(args):
losses = [[], []]
t_time = 0
vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters() + cycle_gan.build_generator_resnet_9blocks_b.parameters()
vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters(
) + cycle_gan.build_generator_resnet_9blocks_b.parameters()
vars_da = cycle_gan.build_gen_discriminator_a.parameters()
vars_db = cycle_gan.build_gen_discriminator_b.parameters()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册