From 85361824baa1425253afd535dbfb9d1d0d9a806d Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 23 Nov 2022 23:26:59 +0800 Subject: [PATCH] Fix basicvsr export error (#729) * fix basicvsr * clean code --- ppgan/models/base_model.py | 1 + ppgan/models/generators/basicvsr.py | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py index d4cd3ac..ce1ee52 100755 --- a/ppgan/models/base_model.py +++ b/ppgan/models/base_model.py @@ -192,6 +192,7 @@ class BaseModel(ABC): for i in range(net["inputs_num"]) ] inputs_num = inputs_num + net["inputs_num"] + self.nets[net["name"]].export_mode = True static_model = paddle.jit.to_static(self.nets[net["name"]], input_spec=input_spec) if output_dir is None: diff --git a/ppgan/models/generators/basicvsr.py b/ppgan/models/generators/basicvsr.py index d7ccbc8..b57290e 100644 --- a/ppgan/models/generators/basicvsr.py +++ b/ppgan/models/generators/basicvsr.py @@ -546,6 +546,22 @@ class BasicVSRNet(nn.Layer): return flows_forward, flows_backward + def compute_flow_export(self, lrs): + """export version of compute_flow + """ + + n, t, c, h, w = lrs.shape + + lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w]) + lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w]) + + flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w]) + + flows_forward = self.spynet(lrs_2, + lrs_1).reshape([n, t - 1, 2, h, w]) + + return flows_forward, flows_backward + def forward(self, lrs): """Forward function for BasicVSR. @@ -566,7 +582,10 @@ class BasicVSRNet(nn.Layer): self.check_if_mirror_extended(lrs) # compute optical flow - flows_forward, flows_backward = self.compute_flow(lrs) + if hasattr(self, 'export_mode') and self.export_mode is True: + flows_forward, flows_backward = self.compute_flow_export(lrs) + else: + flows_forward, flows_backward = self.compute_flow(lrs) # backward-time propgation outputs = [] -- GitLab