未验证 提交 9cbcf94a 编写于 作者: L LielinJiang 提交者: GitHub

fix basicvsr bug on windows platform (#389)

上级 5996700b
......@@ -18,7 +18,7 @@ import paddle.nn as nn
from .builder import MODELS
from .sr_model import BaseSRModel
from .generators.iconvsr import EDVRFeatureExtractor
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from .generators.basicvsr import ResidualBlockNoBN, PixelShufflePack, SPyNet
from ..modules.init import reset_parameters
from ..utils.visual import tensor2img
......@@ -74,10 +74,10 @@ class BasicVSRModel(BaseSRModel):
self.visual_items['output'] = self.output[:, 0, :, :, :]
# pixel loss
loss_pixel = self.pixel_criterion(self.output, self.gt)
loss_pixel.backward()
optims['optim'].step()
self.losses['loss_pixel'] = loss_pixel
self.current_iter += 1
......@@ -91,8 +91,11 @@ class BasicVSRModel(BaseSRModel):
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]):
# print(out_tensor.shape, gt_tensor.shape)
_, t, _, _, _ = self.gt.shape
for i in range(t):
out_tensor = self.output[0, i]
gt_tensor = self.gt[0, i]
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
......@@ -103,10 +106,12 @@ class BasicVSRModel(BaseSRModel):
def init_basicvsr_weight(net):
for m in net.children():
if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)):
if hasattr(m,
'weight') and not isinstance(m,
(nn.BatchNorm, nn.BatchNorm2D)):
reset_parameters(m)
continue
if (not isinstance(
m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, EDVRFeatureExtractor))):
if (not isinstance(m, (ResidualBlockNoBN, PixelShufflePack, SPyNet,
EDVRFeatureExtractor))):
init_basicvsr_weight(m)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册