From 9cbcf94a07759b492927b5446d1114ef192a1f33 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Tue, 10 Aug 2021 15:56:56 +0800 Subject: [PATCH] fix basicvsr bug on windows platform (#389) --- ppgan/models/basicvsr_model.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ppgan/models/basicvsr_model.py b/ppgan/models/basicvsr_model.py index b47effb..70b3eb1 100644 --- a/ppgan/models/basicvsr_model.py +++ b/ppgan/models/basicvsr_model.py @@ -18,7 +18,7 @@ import paddle.nn as nn from .builder import MODELS from .sr_model import BaseSRModel from .generators.iconvsr import EDVRFeatureExtractor -from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet +from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet from ..modules.init import reset_parameters from ..utils.visual import tensor2img @@ -74,10 +74,10 @@ class BasicVSRModel(BaseSRModel): self.visual_items['output'] = self.output[:, 0, :, :, :] # pixel loss loss_pixel = self.pixel_criterion(self.output, self.gt) - + loss_pixel.backward() optims['optim'].step() - + self.losses['loss_pixel'] = loss_pixel self.current_iter += 1 @@ -91,8 +91,11 @@ class BasicVSRModel(BaseSRModel): out_img = [] gt_img = [] - for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]): - # print(out_tensor.shape, gt_tensor.shape) + + _, t, _, _, _ = self.gt.shape + for i in range(t): + out_tensor = self.output[0, i] + gt_tensor = self.gt[0, i] out_img.append(tensor2img(out_tensor, (0., 1.))) gt_img.append(tensor2img(gt_tensor, (0., 1.))) @@ -103,10 +106,12 @@ class BasicVSRModel(BaseSRModel): def init_basicvsr_weight(net): for m in net.children(): - if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)): + if hasattr(m, + 'weight') and not isinstance(m, + (nn.BatchNorm, nn.BatchNorm2D)): reset_parameters(m) continue - if (not isinstance( - m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, EDVRFeatureExtractor))): + if (not isinstance(m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, + EDVRFeatureExtractor))): init_basicvsr_weight(m) -- GitLab