diff --git a/ppgan/models/basicvsr_model.py b/ppgan/models/basicvsr_model.py index b47effb82f5186aaf93cc98c542cc91e38f98c59..70b3eb1075b3dc4e7193da7a802e7d1a7825a8f1 100644 --- a/ppgan/models/basicvsr_model.py +++ b/ppgan/models/basicvsr_model.py @@ -18,7 +18,7 @@ import paddle.nn as nn from .builder import MODELS from .sr_model import BaseSRModel from .generators.iconvsr import EDVRFeatureExtractor -from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet +from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet from ..modules.init import reset_parameters from ..utils.visual import tensor2img @@ -74,10 +74,10 @@ class BasicVSRModel(BaseSRModel): self.visual_items['output'] = self.output[:, 0, :, :, :] # pixel loss loss_pixel = self.pixel_criterion(self.output, self.gt) - + loss_pixel.backward() optims['optim'].step() - + self.losses['loss_pixel'] = loss_pixel self.current_iter += 1 @@ -91,8 +91,11 @@ class BasicVSRModel(BaseSRModel): out_img = [] gt_img = [] - for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]): - # print(out_tensor.shape, gt_tensor.shape) + + _, t, _, _, _ = self.gt.shape + for i in range(t): + out_tensor = self.output[0, i] + gt_tensor = self.gt[0, i] out_img.append(tensor2img(out_tensor, (0., 1.))) gt_img.append(tensor2img(gt_tensor, (0., 1.))) @@ -103,10 +106,12 @@ class BasicVSRModel(BaseSRModel): def init_basicvsr_weight(net): for m in net.children(): - if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)): + if hasattr(m, + 'weight') and not isinstance(m, + (nn.BatchNorm, nn.BatchNorm2D)): reset_parameters(m) continue - if (not isinstance( - m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, EDVRFeatureExtractor))): + if (not isinstance(m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, + EDVRFeatureExtractor))): init_basicvsr_weight(m)