diff --git a/ppgan/models/basicvsr_model.py b/ppgan/models/basicvsr_model.py index 70b3eb1075b3dc4e7193da7a802e7d1a7825a8f1..faa9911d913d71d63955c7ce935982dedbf2da87 100644 --- a/ppgan/models/basicvsr_model.py +++ b/ppgan/models/basicvsr_model.py @@ -83,10 +83,11 @@ class BasicVSRModel(BaseSRModel): self.current_iter += 1 def test_iter(self, metrics=None): + self.gt = self.gt.cpu() self.nets['generator'].eval() with paddle.no_grad(): - self.output = self.nets['generator'](self.lq) - self.visual_items['output'] = self.output[:, 0, :, :, :] + output = self.nets['generator'](self.lq) + self.visual_items['output'] = output[:, 0, :, :, :].cpu() self.nets['generator'].train() out_img = [] @@ -94,7 +95,7 @@ class BasicVSRModel(BaseSRModel): _, t, _, _, _ = self.gt.shape for i in range(t): - out_tensor = self.output[0, i] + out_tensor = output[0, i] gt_tensor = self.gt[0, i] out_img.append(tensor2img(out_tensor, (0., 1.))) gt_img.append(tensor2img(gt_tensor, (0., 1.)))