未验证 提交 85361824 编写于 作者: L LielinJiang 提交者: GitHub

Fix basicvsr export error (#729)

* fix basicvsr

* clean code
上级 7557835d
......@@ -192,6 +192,7 @@ class BaseModel(ABC):
for i in range(net["inputs_num"])
]
inputs_num = inputs_num + net["inputs_num"]
self.nets[net["name"]].export_mode = True
static_model = paddle.jit.to_static(self.nets[net["name"]],
input_spec=input_spec)
if output_dir is None:
......
......@@ -546,6 +546,22 @@ class BasicVSRNet(nn.Layer):
return flows_forward, flows_backward
def compute_flow_export(self, lrs):
"""export version of compute_flow
"""
n, t, c, h, w = lrs.shape
lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])
flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])
flows_forward = self.spynet(lrs_2,
lrs_1).reshape([n, t - 1, 2, h, w])
return flows_forward, flows_backward
def forward(self, lrs):
"""Forward function for BasicVSR.
......@@ -566,7 +582,10 @@ class BasicVSRNet(nn.Layer):
self.check_if_mirror_extended(lrs)
# compute optical flow
flows_forward, flows_backward = self.compute_flow(lrs)
if hasattr(self, 'export_mode') and self.export_mode is True:
flows_forward, flows_backward = self.compute_flow_export(lrs)
else:
flows_forward, flows_backward = self.compute_flow(lrs)
# backward-time propgation
outputs = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册