diff --git a/PaddleCV/gan/data_reader.py b/PaddleCV/gan/data_reader.py index 407855abca1c3328e931841f404a3f80a9b6cc36..6e8a65f65004c4d76ef955e45eaa97fac6a3c77c 100644 --- a/PaddleCV/gan/data_reader.py +++ b/PaddleCV/gan/data_reader.py @@ -99,16 +99,24 @@ class reader_creator(object): def make_reader(self, args, return_name=False): print(self.image_dir, self.list_filename) + self.with_label = False def reader(): batch_out = [] + batch_out_label = [] batch_out_name = [] if self.shuffle: np.random.shuffle(self.lines) - for i, file in enumerate(self.lines): - file = file.strip('\n\r\t ') + for i, line in enumerate(self.lines): + line = line.strip('\n\r\t').split(' ') + if len(line) > 1: + self.with_label = True + batch_out_label.append(line[1]) + file = line[0] + else: + file = line[0] self.name2id[os.path.basename(file)] = i self.id2name[i] = os.path.basename(file) img = Image.open(os.path.join(self.image_dir, file)).convert( @@ -133,10 +141,18 @@ class reader_creator(object): batch_out.append(img) if len(batch_out) == self.batch_size: if return_name: - yield batch_out, batch_out_name + if self.with_label: + yield [[batch_out, batch_out_label, batch_out_name]] + batch_out_label = [] + else: + yield batch_out, batch_out_name batch_out_name = [] else: - yield [batch_out] + if self.with_label: + yield [[batch_out, batch_out_label]] + batch_out_label = [] + else: + yield [batch_out] batch_out = [] return reader @@ -667,8 +683,9 @@ class data_reader(object): image_dir=dataset_dir, list_filename=test_list, batch_size=self.cfg.n_samples) - reader_test = test_reader.get_test_reader( + reader_test = test_reader.make_reader( self.cfg, shuffle=False, return_name=True) id2name = test_reader.id2name batch_num = train_reader.len() - return train_reader, reader_test, batch_num, id2name + reader = train_reader.make_reader(self.cfg) + return reader, reader_test, batch_num, id2name diff --git a/dygraph/cycle_gan/infer.py b/dygraph/cycle_gan/infer.py index de70585f621ae51f1a88da3a1581d1f4266aecd8..b802ace1a86d132ea6acd684373c6bc5dc4cf287 100644 --- a/dygraph/cycle_gan/infer.py +++ b/dygraph/cycle_gan/infer.py @@ -50,7 +50,7 @@ def infer(): out_path = args.output + "/single" if not os.path.exists(out_path): os.makedirs(out_path) - cycle_gan = Cycle_Gan("cycle_gan") + cycle_gan = Cycle_Gan(3) save_dir = args.init_model restore, _ = fluid.load_dygraph(save_dir) cycle_gan.set_dict(restore) diff --git a/dygraph/cycle_gan/test.py b/dygraph/cycle_gan/test.py index ba0b03ba2045102c8acab88a33eb16932db0e9fe..163ae4e9148ce22c207909e362842f731de71d0d 100644 --- a/dygraph/cycle_gan/test.py +++ b/dygraph/cycle_gan/test.py @@ -50,7 +50,7 @@ def test(): out_path = args.output + "/eval" + "/" + str(epoch) if not os.path.exists(out_path): os.makedirs(out_path) - cycle_gan = Cycle_Gan("cycle_gan") + cycle_gan = Cycle_Gan(3) save_dir = args.init_model + str(epoch) restore, _ = fluid.load_dygraph(save_dir) cycle_gan.set_dict(restore) diff --git a/dygraph/cycle_gan/train.py b/dygraph/cycle_gan/train.py index a1422047b0d02f5e6cd9dfaa97e5840d38a7bf69..6a4c43821d16ad87621e630fb44e88cb8c5c6a8a 100644 --- a/dygraph/cycle_gan/train.py +++ b/dygraph/cycle_gan/train.py @@ -44,7 +44,7 @@ add_arg('save_checkpoints', bool, True, "Whether to save checkpoints.") lambda_A = 10.0 lambda_B = 10.0 lambda_identity = 0.5 -tep_per_epoch = 2974 +step_per_epoch = 2974 def optimizer_setting(parameters): @@ -90,7 +90,8 @@ def train(args): losses = [[], []] t_time = 0 - vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters() + cycle_gan.build_generator_resnet_9blocks_b.parameters() + vars_G = cycle_gan.build_generator_resnet_9blocks_a.parameters( + ) + cycle_gan.build_generator_resnet_9blocks_b.parameters() vars_da = cycle_gan.build_gen_discriminator_a.parameters() vars_db = cycle_gan.build_gen_discriminator_b.parameters()