未验证 提交 5170bab1 编写于 作者: L LielinJiang 提交者: GitHub

reduce cuda memory for basicvsr (#393)

上级 95b8187f
...@@ -83,10 +83,11 @@ class BasicVSRModel(BaseSRModel): ...@@ -83,10 +83,11 @@ class BasicVSRModel(BaseSRModel):
self.current_iter += 1 self.current_iter += 1
def test_iter(self, metrics=None): def test_iter(self, metrics=None):
self.gt = self.gt.cpu()
self.nets['generator'].eval() self.nets['generator'].eval()
with paddle.no_grad(): with paddle.no_grad():
self.output = self.nets['generator'](self.lq) output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output[:, 0, :, :, :] self.visual_items['output'] = output[:, 0, :, :, :].cpu()
self.nets['generator'].train() self.nets['generator'].train()
out_img = [] out_img = []
...@@ -94,7 +95,7 @@ class BasicVSRModel(BaseSRModel): ...@@ -94,7 +95,7 @@ class BasicVSRModel(BaseSRModel):
_, t, _, _, _ = self.gt.shape _, t, _, _, _ = self.gt.shape
for i in range(t): for i in range(t):
out_tensor = self.output[0, i] out_tensor = output[0, i]
gt_tensor = self.gt[0, i] gt_tensor = self.gt[0, i]
out_img.append(tensor2img(out_tensor, (0., 1.))) out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.))) gt_img.append(tensor2img(gt_tensor, (0., 1.)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册