未验证 提交 9cbcf94a 编写于 作者: L LielinJiang 提交者: GitHub

fix basicvsr bug on windows platform (#389)

上级 5996700b
......@@ -91,8 +91,11 @@ class BasicVSRModel(BaseSRModel):
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output[0], self.gt[0]):
# print(out_tensor.shape, gt_tensor.shape)
_, t, _, _, _ = self.gt.shape
for i in range(t):
out_tensor = self.output[0, i]
gt_tensor = self.gt[0, i]
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
......@@ -103,10 +106,12 @@ class BasicVSRModel(BaseSRModel):
def init_basicvsr_weight(net):
for m in net.children():
if hasattr(m, 'weight') and not isinstance(m, (nn.BatchNorm, nn.BatchNorm2D)):
if hasattr(m,
'weight') and not isinstance(m,
(nn.BatchNorm, nn.BatchNorm2D)):
reset_parameters(m)
continue
if (not isinstance(
m, (ResidualBlockNoBN, PixelShufflePack, SPyNet, EDVRFeatureExtractor))):
if (not isinstance(m, (ResidualBlockNoBN, PixelShufflePack, SPyNet,
EDVRFeatureExtractor))):
init_basicvsr_weight(m)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册