diff --git a/configs/mprnet_deblurring.yaml b/configs/mprnet_deblurring.yaml index a69fb02f4fb7eefe5a97bb5adbd6723086effc68..4082a6aa0158b2c74a786d0e8098fc9203722430 100644 --- a/configs/mprnet_deblurring.yaml +++ b/configs/mprnet_deblurring.yaml @@ -1,4 +1,5 @@ -total_iters: 100000 +# epoch: 3000 for total batch size=16 +total_iters: 400000 output_dir: output_dir model: @@ -15,38 +16,38 @@ dataset: train: name: MPRTrain rgb_dir: 'data/GoPro/train' - num_workers: 16 - batch_size: 4 + num_workers: 4 + batch_size: 2 img_options: patch_size: 256 test: - name: MPRTrain + name: MPRVal rgb_dir: 'data/GoPro/test' - num_workers: 16 - batch_size: 4 + num_workers: 4 + batch_size: 2 img_options: patch_size: 256 lr_scheduler: name: CosineAnnealingRestartLR - learning_rate: !!float 2e-4 - periods: [25000, 25000, 25000, 25000] - restart_weights: [1, 1, 1, 1] + learning_rate: !!float 1e-4 + periods: [400000] + restart_weights: [1] eta_min: !!float 1e-6 validate: - interval: 10 + interval: 5000 save_img: false metrics: psnr: # metric name, can be arbitrary name: PSNR crop_border: 4 - test_y_channel: True + test_y_channel: false ssim: name: SSIM crop_border: 4 - test_y_channel: True + test_y_channel: false optimizer: name: Adam @@ -59,7 +60,7 @@ optimizer: epsilon: 1e-8 log_config: - interval: 10 + interval: 100 visiual_interval: 5000 snapshot_config: diff --git a/docs/en_US/tutorials/single_image_super_resolution.md b/docs/en_US/tutorials/single_image_super_resolution.md index 44a7e36b23b9747164f487735f958cfa50477722..053fc2f7026cfd369dead2f57bead5e0351c5066 100644 --- a/docs/en_US/tutorials/single_image_super_resolution.md +++ b/docs/en_US/tutorials/single_image_super_resolution.md @@ -130,6 +130,10 @@ The metrics are PSNR / SSIM. | pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 | | drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - | +Deblur models zoo +| model | GoPro | Download Link | +|---|---|---| +| MPRNet | 33.4360 / 0.9410 | [link](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) | diff --git a/docs/zh_CN/tutorials/single_image_super_resolution.md b/docs/zh_CN/tutorials/single_image_super_resolution.md index 7572584a66ef580e57cb9fe303ac283db48b620d..63e5f2b5c8415d33804b5169bbffd7566bbf920b 100644 --- a/docs/zh_CN/tutorials/single_image_super_resolution.md +++ b/docs/zh_CN/tutorials/single_image_super_resolution.md @@ -120,6 +120,10 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集 | paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 | | torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 | +去模糊模型 +| 模型 | GoPro | 下载地址 | +|---|---|---| +| MPRNet | 33.4360 / 0.9410 | [链接](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) | diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 9184e641b991d303b7fc781998e5f3944496c82a..417db0a37fe2d1035802ab5d54b7b31deba63c6e 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -29,6 +29,7 @@ from ..utils.filesystem import makedirs, save, load from ..utils.timer import TimeAverager from ..utils.profiler import add_profiler_step + class IterLoader: def __init__(self, dataloader): self._dataloader = dataloader @@ -429,15 +430,35 @@ class Trainer: def load(self, weight_path): state_dicts = load(weight_path) - for net_name, net in self.model.nets.items(): - if net_name in state_dicts: - net.set_state_dict(state_dicts[net_name]) - self.logger.info( - 'Loaded pretrained weight for net {}'.format(net_name)) + def is_dict_in_dict_weight(state_dict): + if isinstance(state_dict, dict) and len(state_dict) > 0: + val = list(state_dict.values())[0] + if isinstance(val, dict): + return True + else: + return False else: - self.logger.warning( - 'Can not find state dict of net {}. Skip load pretrained weight for net {}' - .format(net_name, net_name)) + return False + + if is_dict_in_dict_weight(state_dicts): + for net_name, net in self.model.nets.items(): + if net_name in state_dicts: + net.set_state_dict(state_dicts[net_name]) + self.logger.info( + 'Loaded pretrained weight for net {}'.format(net_name)) + else: + self.logger.warning( + 'Can not find state dict of net {}. Skip load pretrained weight for net {}' + .format(net_name, net_name)) + else: + assert len(self.model.nets + ) == 1, 'checkpoint only contain weight of one net, \ + but model contains more than one net!' + + net_name, net = list(self.model.nets.items())[0] + net.set_state_dict(state_dicts) + self.logger.info( + 'Loaded pretrained weight for net {}'.format(net_name)) def close(self): """ diff --git a/ppgan/models/criterions/pixel_loss.py b/ppgan/models/criterions/pixel_loss.py index 6e878ad735306222b8a3a12bb85f3ee26c1c992b..ca60f55f96067815157a7c831d40f100f773bd2c 100644 --- a/ppgan/models/criterions/pixel_loss.py +++ b/ppgan/models/criterions/pixel_loss.py @@ -249,23 +249,25 @@ class CalcStyleLoss(): class EdgeLoss(): def __init__(self): k = paddle.to_tensor([[.05, .25, .4, .25, .05]]) - self.kernel = paddle.matmul(k.t(),k).unsqueeze(0).tile([3,1,1,1]) + self.kernel = paddle.matmul(k.t(), k).unsqueeze(0).tile([3, 1, 1, 1]) self.loss = CharbonnierLoss() def conv_gauss(self, img): n_channels, _, kw, kh = self.kernel.shape - img = F.pad(img, [kw//2, kh//2, kw//2, kh//2], mode='replicate') + img = F.pad(img, [kw // 2, kh // 2, kw // 2, kh // 2], mode='replicate') return F.conv2d(img, self.kernel, groups=n_channels) def laplacian_kernel(self, current): - filtered = self.conv_gauss(current) # filter - down = filtered[:,:,::2,::2] # downsample - new_filter = paddle.zeros_like(filtered) - new_filter[:,:,::2,::2] = down*4 # upsample - filtered = self.conv_gauss(new_filter) # filter + filtered = self.conv_gauss(current) # filter + down = filtered[:, :, ::2, ::2] # downsample + new_filter = paddle.zeros_like(filtered) + new_filter.stop_gradient = True + new_filter[:, :, ::2, ::2] = down * 4 # upsample + filtered = self.conv_gauss(new_filter) # filter diff = current - filtered return diff def __call__(self, x, y): + y.stop_gradient = True loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) - return loss \ No newline at end of file + return loss diff --git a/ppgan/models/mpr_model.py b/ppgan/models/mpr_model.py index d88e8f11f441c94c03f295115e111a37ef994880..426b9c3eec16f33990dec2b85e48b8a77b1c445c 100644 --- a/ppgan/models/mpr_model.py +++ b/ppgan/models/mpr_model.py @@ -20,6 +20,7 @@ from .base_model import BaseModel from .generators.builder import build_generator from .criterions.builder import build_criterion from ..modules.init import reset_parameters, init_weights +from ..utils.visual import tensor2img @MODELS.register() @@ -50,12 +51,12 @@ class MPRModel(BaseModel): def setup_input(self, input): self.target = input[0] - self.input_ = input[1] + self.lq = input[1] def train_iter(self, optims=None): optims['optim'].clear_gradients() - restored = self.nets['generator'](self.input_) + restored = self.nets['generator'](self.lq) loss_char = [] loss_edge = [] @@ -75,5 +76,21 @@ class MPRModel(BaseModel): self.losses['loss'] = loss.numpy() def forward(self): - """Run forward pass; called by both functions and .""" pass + + def test_iter(self, metrics=None): + self.nets['generator'].eval() + with paddle.no_grad(): + self.output = self.nets['generator'](self.lq)[0] + self.visual_items['output'] = self.output + self.nets['generator'].train() + + out_img = [] + gt_img = [] + for out_tensor, gt_tensor in zip(self.output, self.target): + out_img.append(tensor2img(out_tensor, (0., 1.))) + gt_img.append(tensor2img(gt_tensor, (0., 1.))) + + if metrics is not None: + for metric in metrics.values(): + metric.update(out_img, gt_img)