diff --git a/ppgan/models/base_model.py b/ppgan/models/base_model.py index d4cd3ac140ecff9991cce6607091c8f8040501ae..ce1ee528add6e5068090a287b6264188f53f6c50 100755 --- a/ppgan/models/base_model.py +++ b/ppgan/models/base_model.py @@ -192,6 +192,7 @@ class BaseModel(ABC): for i in range(net["inputs_num"]) ] inputs_num = inputs_num + net["inputs_num"] + self.nets[net["name"]].export_mode = True static_model = paddle.jit.to_static(self.nets[net["name"]], input_spec=input_spec) if output_dir is None: diff --git a/ppgan/models/generators/basicvsr.py b/ppgan/models/generators/basicvsr.py index d7ccbc8b427c7e4f0b7829e9550e0939660a2854..b57290e90860306057f1270183f7ce8adc40e401 100644 --- a/ppgan/models/generators/basicvsr.py +++ b/ppgan/models/generators/basicvsr.py @@ -546,6 +546,22 @@ class BasicVSRNet(nn.Layer): return flows_forward, flows_backward + def compute_flow_export(self, lrs): + """export version of compute_flow + """ + + n, t, c, h, w = lrs.shape + + lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w]) + lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w]) + + flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w]) + + flows_forward = self.spynet(lrs_2, + lrs_1).reshape([n, t - 1, 2, h, w]) + + return flows_forward, flows_backward + def forward(self, lrs): """Forward function for BasicVSR. @@ -566,7 +582,10 @@ class BasicVSRNet(nn.Layer): self.check_if_mirror_extended(lrs) # compute optical flow - flows_forward, flows_backward = self.compute_flow(lrs) + if hasattr(self, 'export_mode') and self.export_mode is True: + flows_forward, flows_backward = self.compute_flow_export(lrs) + else: + flows_forward, flows_backward = self.compute_flow(lrs) # backward-time propgation outputs = []