diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml index e4e40208e876be1cf4ea88878eded500724d0343..4f5257bb6108d2c7191e84356b49e21796790704 100644 --- a/configs/cyclegan_cityscapes.yaml +++ b/configs/cyclegan_cityscapes.yaml @@ -67,7 +67,7 @@ dataset: batch_size: 1 max_size: inf is_train: False - load_pipeline: + preprocess: - name: LoadImageFromFile key: A - name: LoadImageFromFile diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml index 924c8e4b284377084eec3ea5d5c456c4cf093398..15e449e7607791794cdc9448025f0c62d6f041b1 100644 --- a/configs/cyclegan_horse2zebra.yaml +++ b/configs/cyclegan_horse2zebra.yaml @@ -35,7 +35,7 @@ dataset: batch_size: 1 is_train: True max_size: inf - load_pipeline: + preprocess: - name: LoadImageFromFile key: A - name: LoadImageFromFile @@ -67,7 +67,7 @@ dataset: batch_size: 1 max_size: inf is_train: False - load_pipeline: + preprocess: - name: LoadImageFromFile key: A - name: LoadImageFromFile diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index 0a3fc63ef810faf0e3f3ec6cae93257b400b9ca1..5a24315b60f8547d4ecf14a708554b4c05d78680 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -61,7 +61,7 @@ dataset: dataroot: data/cityscapes/test num_workers: 4 batch_size: 1 - load_pipeline: + preprocess: - name: LoadImageFromFile key: pair - name: SplitPairedImage diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml index 32bf846c3a3fc435830eee374a845c434e6f4e49..6686cfedca2efdc0a55791c828b27f107cd814a5 100644 --- a/configs/pix2pix_cityscapes_2gpus.yaml +++ b/configs/pix2pix_cityscapes_2gpus.yaml @@ -61,7 +61,7 @@ dataset: dataroot: data/cityscapes/test num_workers: 4 batch_size: 1 - load_pipeline: + preprocess: - name: LoadImageFromFile key: pair - name: Transforms diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index 0e044f9cd037850d1d52beaf1870029a80d86630..c7ed53754992052ecf708be7ed2346d6ba070e3a 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -61,7 +61,7 @@ dataset: dataroot: data/facades/test num_workers: 4 batch_size: 1 - load_pipeline: + preprocess: - name: LoadImageFromFile key: pair - name: Transforms diff --git a/ppgan/apps/realsr_predictor.py b/ppgan/apps/realsr_predictor.py index 759d8107f480e5e01f0466a0ecfcaf1cecf29524..19139a88e2020b8503ad0070a871ac28777900c0 100644 --- a/ppgan/apps/realsr_predictor.py +++ b/ppgan/apps/realsr_predictor.py @@ -60,7 +60,8 @@ class RealSRPredictor(BasePredictor): img = self.norm(ori_img) x = paddle.to_tensor(img[np.newaxis, ...]) - out = self.model(x) + with paddle.no_grad(): + out = self.model(x) pred_img = self.denorm(out.numpy()[0]) pred_img = Image.fromarray(pred_img) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 2db728c0a2ead372f3e828342c8f1b9b8285a15d..ab75cb0b04d9c554fbbc4d2cbbacc7b552a4d9fe 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -124,6 +124,9 @@ class Trainer: self.weight_interval = cfg.snapshot_config.interval self.log_interval = cfg.log_config.interval self.visual_interval = cfg.log_config.visiual_interval + if self.by_epoch: + self.weight_interval *= self.iters_per_epoch + self.validate_interval = -1 if cfg.get('validate', None) is not None: self.validate_interval = cfg.validate.get('interval', -1) @@ -177,16 +180,12 @@ class Trainer: self.model.lr_scheduler.step() - if self.by_epoch: - temp = self.current_epoch - else: - temp = self.current_iter - if self.validate_interval > -1 and temp % self.validate_interval == 0: + if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: self.test() - if temp % self.weight_interval == 0: - self.save(temp, 'weight', keep=-1) - self.save(temp) + if self.current_iter % self.weight_interval == 0: + self.save(self.current_iter, 'weight', keep=-1) + self.save(self.current_iter) self.current_iter += 1 @@ -335,7 +334,12 @@ class Trainer: assert name in ['checkpoint', 'weight'] state_dicts = {} - save_filename = 'epoch_%s_%s.pdparams' % (epoch, name) + if self.by_epoch: + save_filename = 'epoch_%s_%s.pdparams' % ( + epoch // self.iters_per_epoch, name) + else: + save_filename = 'iter_%s_%s.pdparams' % (epoch, name) + save_path = os.path.join(self.output_dir, save_filename) for net_name, net in self.model.nets.items(): state_dicts[net_name] = net.state_dict() @@ -353,9 +357,16 @@ class Trainer: if keep > 0: try: - checkpoint_name_to_be_removed = os.path.join( - self.output_dir, - 'epoch_%s_%s.pdparams' % (epoch - keep, name)) + if self.by_epoch: + checkpoint_name_to_be_removed = os.path.join( + self.output_dir, 'epoch_%s_%s.pdparams' % + ((epoch - keep * self.weight_interval) // + self.iters_per_epoch, name)) + else: + checkpoint_name_to_be_removed = os.path.join( + self.output_dir, 'iter_%s_%s.pdparams' % + (epoch - keep * self.weight_interval, name)) + if os.path.exists(checkpoint_name_to_be_removed): os.remove(checkpoint_name_to_be_removed) @@ -366,7 +377,7 @@ class Trainer: state_dicts = load(checkpoint_path) if state_dicts.get('epoch', None) is not None: self.start_epoch = state_dicts['epoch'] + 1 - self.global_steps = self.steps_per_epoch * state_dicts['epoch'] + self.global_steps = self.iters_per_epoch * state_dicts['epoch'] for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) diff --git a/ppgan/utils/options.py b/ppgan/utils/options.py index 9aff1618e2ebefb45c1860445bb74dc10a01ac04..e87477af0ca867e1d53d0d64ae9249e9ca2f1a1d 100644 --- a/ppgan/utils/options.py +++ b/ppgan/utils/options.py @@ -17,7 +17,8 @@ import argparse def parse_args(): parser = argparse.ArgumentParser(description='PaddleGAN') - parser.add_argument('--config-file', + parser.add_argument('-c', + '--config-file', metavar="FILE", help='config file path') # cuda setting diff --git a/ppgan/utils/setup.py b/ppgan/utils/setup.py index 588260e2def485b5bdc1aaa2ab2b115ca8fad68d..e37bde59793e33160edee56368ce9c817223d3de 100644 --- a/ppgan/utils/setup.py +++ b/ppgan/utils/setup.py @@ -26,8 +26,10 @@ def setup(args, cfg): cfg.is_train = True cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) - cfg.output_dir = os.path.join(cfg.output_dir, - str(cfg.model.name) + cfg.timestamp) + cfg.output_dir = os.path.join( + cfg.output_dir, + os.path.splitext(os.path.basename(str(args.config_file)))[0] + + cfg.timestamp) logger = setup_logger(cfg.output_dir)