From 89dbb63f670fdcff4f4c641ca4778df8d300769a Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 6 Jan 2021 10:49:10 +0800 Subject: [PATCH] Fix some bugs (#140) * fix some bugs * update configs --- configs/cyclegan_cityscapes.yaml | 2 +- configs/cyclegan_horse2zebra.yaml | 4 +-- configs/pix2pix_cityscapes.yaml | 2 +- configs/pix2pix_cityscapes_2gpus.yaml | 2 +- configs/pix2pix_facades.yaml | 2 +- ppgan/apps/realsr_predictor.py | 3 ++- ppgan/engine/trainer.py | 37 +++++++++++++++++---------- ppgan/utils/options.py | 3 ++- ppgan/utils/setup.py | 6 +++-- 9 files changed, 38 insertions(+), 23 deletions(-) diff --git a/configs/cyclegan_cityscapes.yaml b/configs/cyclegan_cityscapes.yaml index e4e4020..4f5257b 100644 --- a/configs/cyclegan_cityscapes.yaml +++ b/configs/cyclegan_cityscapes.yaml @@ -67,7 +67,7 @@ dataset: batch_size: 1 max_size: inf is_train: False - load_pipeline: + preprocess: - name: LoadImageFromFile key: A - name: LoadImageFromFile diff --git a/configs/cyclegan_horse2zebra.yaml b/configs/cyclegan_horse2zebra.yaml index 924c8e4..15e449e 100644 --- a/configs/cyclegan_horse2zebra.yaml +++ b/configs/cyclegan_horse2zebra.yaml @@ -35,7 +35,7 @@ dataset: batch_size: 1 is_train: True max_size: inf - load_pipeline: + preprocess: - name: LoadImageFromFile key: A - name: LoadImageFromFile @@ -67,7 +67,7 @@ dataset: batch_size: 1 max_size: inf is_train: False - load_pipeline: + preprocess: - name: LoadImageFromFile key: A - name: LoadImageFromFile diff --git a/configs/pix2pix_cityscapes.yaml b/configs/pix2pix_cityscapes.yaml index 0a3fc63..5a24315 100644 --- a/configs/pix2pix_cityscapes.yaml +++ b/configs/pix2pix_cityscapes.yaml @@ -61,7 +61,7 @@ dataset: dataroot: data/cityscapes/test num_workers: 4 batch_size: 1 - load_pipeline: + preprocess: - name: LoadImageFromFile key: pair - name: SplitPairedImage diff --git a/configs/pix2pix_cityscapes_2gpus.yaml b/configs/pix2pix_cityscapes_2gpus.yaml index 32bf846..6686cfe 100644 --- a/configs/pix2pix_cityscapes_2gpus.yaml +++ b/configs/pix2pix_cityscapes_2gpus.yaml @@ -61,7 +61,7 @@ dataset: dataroot: data/cityscapes/test num_workers: 4 batch_size: 1 - load_pipeline: + preprocess: - name: LoadImageFromFile key: pair - name: Transforms diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index 0e044f9..c7ed537 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -61,7 +61,7 @@ dataset: dataroot: data/facades/test num_workers: 4 batch_size: 1 - load_pipeline: + preprocess: - name: LoadImageFromFile key: pair - name: Transforms diff --git a/ppgan/apps/realsr_predictor.py b/ppgan/apps/realsr_predictor.py index 759d810..19139a8 100644 --- a/ppgan/apps/realsr_predictor.py +++ b/ppgan/apps/realsr_predictor.py @@ -60,7 +60,8 @@ class RealSRPredictor(BasePredictor): img = self.norm(ori_img) x = paddle.to_tensor(img[np.newaxis, ...]) - out = self.model(x) + with paddle.no_grad(): + out = self.model(x) pred_img = self.denorm(out.numpy()[0]) pred_img = Image.fromarray(pred_img) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 2db728c..ab75cb0 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -124,6 +124,9 @@ class Trainer: self.weight_interval = cfg.snapshot_config.interval self.log_interval = cfg.log_config.interval self.visual_interval = cfg.log_config.visiual_interval + if self.by_epoch: + self.weight_interval *= self.iters_per_epoch + self.validate_interval = -1 if cfg.get('validate', None) is not None: self.validate_interval = cfg.validate.get('interval', -1) @@ -177,16 +180,12 @@ class Trainer: self.model.lr_scheduler.step() - if self.by_epoch: - temp = self.current_epoch - else: - temp = self.current_iter - if self.validate_interval > -1 and temp % self.validate_interval == 0: + if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0: self.test() - if temp % self.weight_interval == 0: - self.save(temp, 'weight', keep=-1) - self.save(temp) + if self.current_iter % self.weight_interval == 0: + self.save(self.current_iter, 'weight', keep=-1) + self.save(self.current_iter) self.current_iter += 1 @@ -335,7 +334,12 @@ class Trainer: assert name in ['checkpoint', 'weight'] state_dicts = {} - save_filename = 'epoch_%s_%s.pdparams' % (epoch, name) + if self.by_epoch: + save_filename = 'epoch_%s_%s.pdparams' % ( + epoch // self.iters_per_epoch, name) + else: + save_filename = 'iter_%s_%s.pdparams' % (epoch, name) + save_path = os.path.join(self.output_dir, save_filename) for net_name, net in self.model.nets.items(): state_dicts[net_name] = net.state_dict() @@ -353,9 +357,16 @@ class Trainer: if keep > 0: try: - checkpoint_name_to_be_removed = os.path.join( - self.output_dir, - 'epoch_%s_%s.pdparams' % (epoch - keep, name)) + if self.by_epoch: + checkpoint_name_to_be_removed = os.path.join( + self.output_dir, 'epoch_%s_%s.pdparams' % + ((epoch - keep * self.weight_interval) // + self.iters_per_epoch, name)) + else: + checkpoint_name_to_be_removed = os.path.join( + self.output_dir, 'iter_%s_%s.pdparams' % + (epoch - keep * self.weight_interval, name)) + if os.path.exists(checkpoint_name_to_be_removed): os.remove(checkpoint_name_to_be_removed) @@ -366,7 +377,7 @@ class Trainer: state_dicts = load(checkpoint_path) if state_dicts.get('epoch', None) is not None: self.start_epoch = state_dicts['epoch'] + 1 - self.global_steps = self.steps_per_epoch * state_dicts['epoch'] + self.global_steps = self.iters_per_epoch * state_dicts['epoch'] for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) diff --git a/ppgan/utils/options.py b/ppgan/utils/options.py index 9aff161..e87477a 100644 --- a/ppgan/utils/options.py +++ b/ppgan/utils/options.py @@ -17,7 +17,8 @@ import argparse def parse_args(): parser = argparse.ArgumentParser(description='PaddleGAN') - parser.add_argument('--config-file', + parser.add_argument('-c', + '--config-file', metavar="FILE", help='config file path') # cuda setting diff --git a/ppgan/utils/setup.py b/ppgan/utils/setup.py index 588260e..e37bde5 100644 --- a/ppgan/utils/setup.py +++ b/ppgan/utils/setup.py @@ -26,8 +26,10 @@ def setup(args, cfg): cfg.is_train = True cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) - cfg.output_dir = os.path.join(cfg.output_dir, - str(cfg.model.name) + cfg.timestamp) + cfg.output_dir = os.path.join( + cfg.output_dir, + os.path.splitext(os.path.basename(str(args.config_file)))[0] + + cfg.timestamp) logger = setup_logger(cfg.output_dir) -- GitLab