diff --git a/applications/DAIN/predict.py b/applications/DAIN/predict.py index 6c6e5234fa1584e3704f8873a3cfa658ad689e71..38c1d6baa4c833da966baf30b0ef5ee5e7e4f4ef 100644 --- a/applications/DAIN/predict.py +++ b/applications/DAIN/predict.py @@ -8,6 +8,7 @@ import time import glob import numpy as np from imageio import imread, imsave +from tqdm import tqdm import cv2 import paddle.fluid as fluid @@ -175,8 +176,7 @@ class VideoFrameInterp(object): if not os.path.exists(os.path.join(frame_path_combined, vidname)): os.makedirs(os.path.join(frame_path_combined, vidname)) - for i in range(frame_num - 1): - print(frames[i]) + for i in tqdm(range(frame_num - 1)): first = frames[i] second = frames[i + 1] @@ -208,12 +208,10 @@ class VideoFrameInterp(object): assert (X0.shape[1] == X1.shape[1]) assert (X0.shape[2] == X1.shape[2]) - print("size before padding ", X0.shape) X0 = np.pad(X0, ((0,0), (padding_top, padding_bottom), \ (padding_left, padding_right)), mode='edge') X1 = np.pad(X1, ((0,0), (padding_top, padding_bottom), \ (padding_left, padding_right)), mode='edge') - print("size after padding ", X0.shape) X0 = np.expand_dims(X0, axis=0) X1 = np.expand_dims(X1, axis=0) @@ -233,8 +231,6 @@ class VideoFrameInterp(object): proc_timer.update(time.time() - proc_end) tot_timer.update(time.time() - end) end = time.time() - print("*********** current image process time \t " + - str(time.time() - proc_end) + "s *********") y_ = [ np.transpose( diff --git a/applications/DAIN/pwcnet/pwcnet.py b/applications/DAIN/pwcnet/pwcnet.py index 75bd7e4cb0071e9342814e9d165abe589ea8de8e..effdc623b4bf8527bf542ffec5ebb5d07d4a81b4 100644 --- a/applications/DAIN/pwcnet/pwcnet.py +++ b/applications/DAIN/pwcnet/pwcnet.py @@ -17,8 +17,7 @@ import numpy as np import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import Conv2D, Conv2DTranspose - -from .correlation_op.correlation import correlation +from paddle.fluid.contrib import correlation __all__ = ['pwc_dc_net'] diff --git a/applications/DAIN/util.py b/applications/DAIN/util.py index 3b83a1bb7aeff9a6a5dc04d32b57f9ccfa491ea5..3efbfe0dc7cac0aeed1c624af9c192c381b4fdc5 100644 --- a/applications/DAIN/util.py +++ b/applications/DAIN/util.py @@ -22,7 +22,7 @@ class AverageMeter(object): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): - ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] vid_name = vid_path.split('/')[-1].split('.')[0] out_full_path = os.path.join(outpath, vid_name) @@ -55,30 +55,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ''.join(cmd) - print(cmd) + if os.system(cmd) == 0: - print('Video: {} done'.format(vid_name)) + pass else: - print('Video: {} error'.format(vid_name)) - print('') + print('ffmpeg process video: {} error'.format(vid_name)) + sys.stdout.flush() return out_full_path def frames_to_video_ffmpeg(framepath, videopath, r): - ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] cmd = ffmpeg + [ ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath ] cmd = ''.join(cmd) - print(cmd) if os.system(cmd) == 0: - print('Video: {} done'.format(videopath)) + pass else: - print('Video: {} error'.format(videopath)) - print('') + print('ffmpeg process video: {} error'.format(videopath)) + sys.stdout.flush() @@ -99,7 +98,8 @@ def combine_frames(input, interpolated, combined, num_frames): for k in range(num_frames): src = frames2[i * num_frames + k] dst = os.path.join( - combined, '{:08d}.png'.format(i * (num_frames + 1) + k + 1)) + combined, + '{:08d}.png'.format(i * (num_frames + 1) + k + 1)) shutil.copy2(src, dst) except Exception as e: print(e) diff --git a/applications/DeOldify/hook.py b/applications/DeOldify/hook.py index ebc75a6f59b187a379cb080344eac77608adb09a..a9ad7d488a6d87b1dd27934dea59cd05ad70ca66 100644 --- a/applications/DeOldify/hook.py +++ b/applications/DeOldify/hook.py @@ -3,14 +3,16 @@ import numpy as np import paddle import paddle.nn as nn + def is_listy(x): - return isinstance(x, (tuple,list)) + return isinstance(x, (tuple, list)) class Hook(): "Create a hook on `m` with `hook_func`." + def __init__(self, m, hook_func, is_forward=True, detach=True): - self.hook_func,self.detach,self.stored = hook_func,detach,None + self.hook_func, self.detach, self.stored = hook_func, detach, None f = m.register_forward_post_hook if is_forward else m.register_backward_hook self.hook = f(self.hook_fn) self.removed = False @@ -18,64 +20,90 @@ class Hook(): def hook_fn(self, module, input, output): "Applies `hook_func` to `module`, `input`, `output`." if self.detach: - input = (o.detach() for o in input ) if is_listy(input ) else input.detach() - output = (o.detach() for o in output) if is_listy(output) else output.detach() + input = (o.detach() + for o in input) if is_listy(input) else input.detach() + output = (o.detach() + for o in output) if is_listy(output) else output.detach() self.stored = self.hook_func(module, input, output) def remove(self): "Remove the hook from the model." if not self.removed: self.hook.remove() - self.removed=True + self.removed = True + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() - def __enter__(self, *args): return self - def __exit__(self, *args): self.remove() class Hooks(): "Create several hooks on the modules in `ms` with `hook_func`." + def __init__(self, ms, hook_func, is_forward=True, detach=True): self.hooks = [] try: for m in ms: self.hooks.append(Hook(m, hook_func, is_forward, detach)) except Exception as e: - print(e) + pass + + def __getitem__(self, i: int) -> Hook: + return self.hooks[i] + + def __len__(self) -> int: + return len(self.hooks) + + def __iter__(self): + return iter(self.hooks) - def __getitem__(self,i:int)->Hook: return self.hooks[i] - def __len__(self)->int: return len(self.hooks) - def __iter__(self): return iter(self.hooks) @property - def stored(self): return [o.stored for o in self] + def stored(self): + return [o.stored for o in self] def remove(self): "Remove the hooks from the model." - for h in self.hooks: h.remove() + for h in self.hooks: + h.remove() - def __enter__(self, *args): return self - def __exit__ (self, *args): self.remove() + def __enter__(self, *args): + return self -def _hook_inner(m,i,o): return o if isinstance(o, paddle.framework.Variable) else o if is_listy(o) else list(o) + def __exit__(self, *args): + self.remove() -def hook_output (module, detach=True, grad=False): + +def _hook_inner(m, i, o): + return o if isinstance( + o, paddle.framework.Variable) else o if is_listy(o) else list(o) + + +def hook_output(module, detach=True, grad=False): "Return a `Hook` that stores activations of `module` in `self.stored`" return Hook(module, _hook_inner, detach=detach, is_forward=not grad) + def hook_outputs(modules, detach=True, grad=False): "Return `Hooks` that store activations of all `modules` in `self.stored`" return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad) -def model_sizes(m, size=(64,64)): + +def model_sizes(m, size=(64, 64)): "Pass a dummy input through the model `m` to get the various sizes of activations." with hook_outputs(m) as hooks: x = dummy_eval(m, size) return [o.stored.shape for o in hooks] -def dummy_eval(m, size=(64,64)): + +def dummy_eval(m, size=(64, 64)): "Pass a `dummy_batch` in evaluation mode in `m` with `size`." m.eval() return m(dummy_batch(size)) -def dummy_batch(size=(64,64), ch_in=3): + +def dummy_batch(size=(64, 64), ch_in=3): "Create a dummy batch to go through `m` with `size`." arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1 return paddle.to_tensor(arr) diff --git a/applications/DeOldify/model.py b/applications/DeOldify/model.py index f763eebef8ecd4edadf9d9b47ca3a6e0332ec5f1..9f97ed8667a70c248c8d6d075e4b4c7f05f186d0 100644 --- a/applications/DeOldify/model.py +++ b/applications/DeOldify/model.py @@ -5,14 +5,13 @@ import paddle.nn.functional as F from resnet_backbone import resnet34, resnet101 from hook import hook_outputs, model_sizes, dummy_eval -# from weight_norm import weight_norm from spectral_norm import Spectralnorm -from conv import Conv1D from paddle import fluid class SequentialEx(nn.Layer): "Like `nn.Sequential`, but with ModuleList semantics, and can access module input" + def __init__(self, *layers): super().__init__() self.layers = nn.LayerList(layers) @@ -28,14 +27,32 @@ class SequentialEx(nn.Layer): res = nres return res - def __getitem__(self,i): return self.layers[i] - def append(self,l): return self.layers.append(l) - def extend(self,l): return self.layers.extend(l) - def insert(self,i,l): return self.layers.insert(i,l) + def __getitem__(self, i): + return self.layers[i] + + def append(self, l): + return self.layers.append(l) + + def extend(self, l): + return self.layers.extend(l) + + def insert(self, i, l): + return self.layers.insert(i, l) class Deoldify(SequentialEx): - def __init__(self, encoder, n_classes, blur=False, blur_final=True, self_attention=False, y_range=None, last_cross=True, bottle=False, norm_type='Batch', nf_factor=1, **kwargs): + def __init__(self, + encoder, + n_classes, + blur=False, + blur_final=True, + self_attention=False, + y_range=None, + last_cross=True, + bottle=False, + norm_type='Batch', + nf_factor=1, + **kwargs): imsize = (256, 256) sfs_szs = model_sizes(encoder, size=imsize) @@ -47,12 +64,14 @@ class Deoldify(SequentialEx): extra_bn = norm_type == 'Spectral' ni = sfs_szs[-1][1] middle_conv = nn.Sequential( - custom_conv_layer( - ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn - ), - custom_conv_layer( - ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn - ), + custom_conv_layer(ni, + ni * 2, + norm_type=norm_type, + extra_bn=extra_bn), + custom_conv_layer(ni * 2, + ni, + norm_type=norm_type, + extra_bn=extra_bn), ) layers = [encoder, nn.BatchNorm(ni), nn.ReLU(), middle_conv] @@ -65,18 +84,16 @@ class Deoldify(SequentialEx): n_out = nf if not_final else nf // 2 - unet_block = UnetBlockWide( - up_in_c, - x_in_c, - n_out, - self.sfs[i], - final_div=not_final, - blur=blur, - self_attention=sa, - norm_type=norm_type, - extra_bn=extra_bn, - **kwargs - ) + unet_block = UnetBlockWide(up_in_c, + x_in_c, + n_out, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + **kwargs) unet_block.eval() layers.append(unet_block) x = unet_block(x) @@ -87,32 +104,34 @@ class Deoldify(SequentialEx): if last_cross: layers.append(MergeLayer(dense=True)) ni += 3 - layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs)) + layers.append( + res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs)) layers += [ - custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type) + custom_conv_layer(ni, + n_classes, + ks=1, + use_activ=False, + norm_type=norm_type) ] if y_range is not None: layers.append(SigmoidRange(*y_range)) super().__init__(*layers) - -def custom_conv_layer( - ni: int, - nf: int, - ks: int = 3, - stride: int = 1, - padding: int = None, - bias: bool = None, - is_1d: bool = False, - norm_type='Batch', - use_activ: bool = True, - leaky: float = None, - transpose: bool = False, - self_attention: bool = False, - extra_bn: bool = False, - **kwargs -): +def custom_conv_layer(ni: int, + nf: int, + ks: int = 3, + stride: int = 1, + padding: int = None, + bias: bool = None, + is_1d: bool = False, + norm_type='Batch', + use_activ: bool = True, + leaky: float = None, + transpose: bool = False, + self_attention: bool = False, + extra_bn: bool = False, + **kwargs): "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers." if padding is None: padding = (ks - 1) // 2 if not transpose else 0 @@ -121,12 +140,15 @@ def custom_conv_layer( bias = not bn conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d - conv = conv_func(ni, nf, kernel_size=ks, bias_attr=bias, stride=stride, padding=padding) + conv = conv_func(ni, + nf, + kernel_size=ks, + bias_attr=bias, + stride=stride, + padding=padding) if norm_type == 'Weight': - print('use weight norm') conv = nn.utils.weight_norm(conv) elif norm_type == 'Spectral': - # pass conv = Spectralnorm(conv) layers = [conv] if use_activ: @@ -135,11 +157,11 @@ def custom_conv_layer( layers.append((nn.BatchNorm if is_1d else nn.BatchNorm)(nf)) if self_attention: layers.append(SelfAttention(nf)) - + return nn.Sequential(*layers) -def relu(inplace:bool=False, leaky:float=None): +def relu(inplace: bool = False, leaky: float = None): "Return a relu activation, maybe `leaky` and `inplace`." return nn.LeakyReLU(leaky) if leaky is not None else nn.ReLU() @@ -147,29 +169,31 @@ def relu(inplace:bool=False, leaky:float=None): class UnetBlockWide(nn.Layer): "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`." - def __init__( - self, - up_in_c: int, - x_in_c: int, - n_out: int, - hook, - final_div: bool = True, - blur: bool = False, - leaky: float = None, - self_attention: bool = False, - **kwargs - ): + def __init__(self, + up_in_c: int, + x_in_c: int, + n_out: int, + hook, + final_div: bool = True, + blur: bool = False, + leaky: float = None, + self_attention: bool = False, + **kwargs): super().__init__() self.hook = hook up_out = x_out = n_out // 2 - self.shuf = CustomPixelShuffle_ICNR( - up_in_c, up_out, blur=blur, leaky=leaky, **kwargs - ) + self.shuf = CustomPixelShuffle_ICNR(up_in_c, + up_out, + blur=blur, + leaky=leaky, + **kwargs) self.bn = nn.BatchNorm(x_in_c) ni = up_out + x_in_c - self.conv = custom_conv_layer( - ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs - ) + self.conv = custom_conv_layer(ni, + x_out, + leaky=leaky, + self_attention=self_attention, + **kwargs) self.relu = relu(leaky=leaky) def forward(self, up_in): @@ -186,29 +210,32 @@ class UnetBlockDeep(paddle.fluid.Layer): "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`." def __init__( - self, - up_in_c: int, - x_in_c: int, - # hook: Hook, - final_div: bool = True, - blur: bool = False, - leaky: float = None, - self_attention: bool = False, - nf_factor: float = 1.0, - **kwargs - ): + self, + up_in_c: int, + x_in_c: int, + # hook: Hook, + final_div: bool = True, + blur: bool = False, + leaky: float = None, + self_attention: bool = False, + nf_factor: float = 1.0, + **kwargs): super().__init__() - - self.shuf = CustomPixelShuffle_ICNR( - up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs - ) + + self.shuf = CustomPixelShuffle_ICNR(up_in_c, + up_in_c // 2, + blur=blur, + leaky=leaky, + **kwargs) self.bn = nn.BatchNorm(x_in_c) ni = up_in_c // 2 + x_in_c nf = int((ni if final_div else ni // 2) * nf_factor) self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs) - self.conv2 = custom_conv_layer( - nf, nf, leaky=leaky, self_attention=self_attention, **kwargs - ) + self.conv2 = custom_conv_layer(nf, + nf, + leaky=leaky, + self_attention=self_attention, + **kwargs) self.relu = relu(leaky=leaky) def forward(self, up_in): @@ -228,34 +255,61 @@ def ifnone(a, b): class PixelShuffle_ICNR(nn.Layer): "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`." - def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, norm_type='Weight', leaky:float=None): + + def __init__(self, + ni: int, + nf: int = None, + scale: int = 2, + blur: bool = False, + norm_type='Weight', + leaky: float = None): super().__init__() nf = ifnone(nf, ni) - self.conv = conv_layer(ni, nf*(scale**2), ks=1, norm_type=norm_type, use_activ=False) - + self.conv = conv_layer(ni, + nf * (scale**2), + ks=1, + norm_type=norm_type, + use_activ=False) + self.shuf = PixelShuffle(scale) - - self.pad = ReplicationPad2d((1,0,1,0)) + + self.pad = ReplicationPad2d((1, 0, 1, 0)) self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg') self.relu = relu(True, leaky=leaky) - def forward(self,x): + def forward(self, x): x = self.shuf(self.relu(self.conv(x))) return self.blur(self.pad(x)) if self.blur else x -def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=None, is_1d:bool=False, - norm_type='Batch', use_activ:bool=True, leaky:float=None, - transpose:bool=False, init=None, self_attention:bool=False): + +def conv_layer(ni: int, + nf: int, + ks: int = 3, + stride: int = 1, + padding: int = None, + bias: bool = None, + is_1d: bool = False, + norm_type='Batch', + use_activ: bool = True, + leaky: float = None, + transpose: bool = False, + init=None, + self_attention: bool = False): "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers." - if padding is None: padding = (ks-1)//2 if not transpose else 0 + if padding is None: padding = (ks - 1) // 2 if not transpose else 0 bn = norm_type in ('Batch', 'BatchZero') if bias is None: bias = not bn conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d - - conv = conv_func(ni, nf, kernel_size=ks, bias_attr=bias, stride=stride, padding=padding) - if norm_type=='Weight': + + conv = conv_func(ni, + nf, + kernel_size=ks, + bias_attr=bias, + stride=stride, + padding=padding) + if norm_type == 'Weight': conv = nn.utils.weight_norm(conv) - elif norm_type=='Spectral': + elif norm_type == 'Spectral': conv = Spectralnorm(conv) layers = [conv] @@ -268,26 +322,27 @@ def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bo class CustomPixelShuffle_ICNR(paddle.fluid.Layer): "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`." - def __init__( - self, - ni: int, - nf: int = None, - scale: int = 2, - blur: bool = False, - leaky: float = None, - **kwargs - ): + def __init__(self, + ni: int, + nf: int = None, + scale: int = 2, + blur: bool = False, + leaky: float = None, + **kwargs): super().__init__() nf = ifnone(nf, ni) - self.conv = custom_conv_layer( - ni, nf * (scale ** 2), ks=1, use_activ=False, **kwargs - ) - + self.conv = custom_conv_layer(ni, + nf * (scale**2), + ks=1, + use_activ=False, + **kwargs) + self.shuf = PixelShuffle(scale) - + self.pad = ReplicationPad2d((1, 0, 1, 0)) self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg') - self.relu = nn.LeakyReLU(leaky) if leaky is not None else nn.ReLU()#relu(True, leaky=leaky) + self.relu = nn.LeakyReLU( + leaky) if leaky is not None else nn.ReLU() #relu(True, leaky=leaky) def forward(self, x): x = self.shuf(self.relu(self.conv(x))) @@ -296,34 +351,43 @@ class CustomPixelShuffle_ICNR(paddle.fluid.Layer): class MergeLayer(paddle.fluid.Layer): "Merge a shortcut with the result of the module by adding them or concatenating thme if `dense=True`." - def __init__(self, dense:bool=False): + + def __init__(self, dense: bool = False): super().__init__() - self.dense=dense + self.dense = dense self.orig = None - def forward(self, x): - out = paddle.concat([x,self.orig], axis=1) if self.dense else (x+self.orig) + def forward(self, x): + out = paddle.concat([x, self.orig], + axis=1) if self.dense else (x + self.orig) self.orig = None return out -def res_block(nf, dense:bool=False, norm_type='Batch', bottle:bool=False, **conv_kwargs): +def res_block(nf, + dense: bool = False, + norm_type='Batch', + bottle: bool = False, + **conv_kwargs): "Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`." norm2 = norm_type - if not dense and (norm_type=='Batch'): norm2 = 'BatchZero' - nf_inner = nf//2 if bottle else nf - return SequentialEx(conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs), - conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs), - MergeLayer(dense)) + if not dense and (norm_type == 'Batch'): norm2 = 'BatchZero' + nf_inner = nf // 2 if bottle else nf + return SequentialEx( + conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs), + conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs), + MergeLayer(dense)) class SigmoidRange(paddle.fluid.Layer): "Sigmoid module with range `(low,x_max)`" + def __init__(self, low, high): super().__init__() - self.low,self.high = low,high + self.low, self.high = low, high - def forward(self, x): return sigmoid_range(x, self.low, self.high) + def forward(self, x): + return sigmoid_range(x, self.low, self.high) def sigmoid_range(x, low, high): @@ -331,7 +395,6 @@ def sigmoid_range(x, low, high): return F.sigmoid(x) * (high - low) + low - class PixelShuffle(paddle.fluid.Layer): def __init__(self, upscale_factor): super(PixelShuffle, self).__init__() @@ -349,7 +412,13 @@ class ReplicationPad2d(nn.Layer): def forward(self, x): return paddle.fluid.layers.pad2d(x, self.size, mode="edge") -def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False): + +def conv1d(ni: int, + no: int, + ks: int = 1, + stride: int = 1, + padding: int = 0, + bias: bool = False): "Create and initialize a `nn.Conv1d` layer with spectral normalization." conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias_attr=bias) return Spectralnorm(conv) @@ -357,30 +426,35 @@ def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=Fals class SelfAttention(nn.Layer): "Self attention layer for nd." + def __init__(self, n_channels): super().__init__() - self.query = conv1d(n_channels, n_channels//8) - self.key = conv1d(n_channels, n_channels//8) + self.query = conv1d(n_channels, n_channels // 8) + self.key = conv1d(n_channels, n_channels // 8) self.value = conv1d(n_channels, n_channels) - self.gamma = self.create_parameter(shape=[1], - default_initializer=paddle.fluid.initializer.Constant(0.0))#nn.Parameter(tensor([0.])) + self.gamma = self.create_parameter( + shape=[1], + default_initializer=paddle.fluid.initializer.Constant( + 0.0)) #nn.Parameter(tensor([0.])) def forward(self, x): #Notation from https://arxiv.org/pdf/1805.08318.pdf size = x.shape - x = paddle.reshape(x, list(size[:2]) + [-1]) - f,g,h = self.query(x),self.key(x),self.value(x) - - beta = paddle.nn.functional.softmax(paddle.bmm(paddle.transpose(f, [0, 2, 1]), g), axis=1) + x = paddle.reshape(x, list(size[:2]) + [-1]) + f, g, h = self.query(x), self.key(x), self.value(x) + + beta = paddle.nn.functional.softmax(paddle.bmm( + paddle.transpose(f, [0, 2, 1]), g), + axis=1) o = self.gamma * paddle.bmm(h, beta) + x return paddle.reshape(o, size) + def _get_sfs_idxs(sizes): "Get the indexes of the layers where the size of the activation changes." feature_szs = [size[-1] for size in sizes] sfs_idxs = list( - np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0] - ) + np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]) if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs return sfs_idxs @@ -391,5 +465,11 @@ def build_model(): cut = -2 encoder = nn.Sequential(*list(backbone.children())[:cut]) - model = Deoldify(encoder, 3, blur=True, y_range=(-3, 3), norm_type='Spectral', self_attention=True, nf_factor=2) + model = Deoldify(encoder, + 3, + blur=True, + y_range=(-3, 3), + norm_type='Spectral', + self_attention=True, + nf_factor=2) return model diff --git a/applications/DeOldify/predict.py b/applications/DeOldify/predict.py index df77d48de177f8cf2af06d2e21470d8a6e40ceb0..ce637fefb2c49034112a5b59dd657379bd31e8ae 100644 --- a/applications/DeOldify/predict.py +++ b/applications/DeOldify/predict.py @@ -20,35 +20,44 @@ from paddle.utils.download import get_path_from_url parser = argparse.ArgumentParser(description='DeOldify') parser.add_argument('--input', type=str, default='none', help='Input video') parser.add_argument('--output', type=str, default='output', help='output dir') +parser.add_argument('--render_factor', + type=int, + default=32, + help='model inputsize=render_factor*16') parser.add_argument('--weight_path', type=str, - default='none', + default=None, help='Path to the reference image directory') DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' def frames_to_video_ffmpeg(framepath, videopath, r): - ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] cmd = ffmpeg + [ ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath ] cmd = ''.join(cmd) - print(cmd) if os.system(cmd) == 0: - print('Video: {} done'.format(videopath)) + pass else: - print('Video: {} error'.format(videopath)) - print('') + print('ffmpeg process video: {} error'.format(videopath)) + sys.stdout.flush() class DeOldifyPredictor(): - def __init__(self, input, output, batch_size=1, weight_path=None): + def __init__(self, + input, + output, + batch_size=1, + weight_path=None, + render_factor=32): self.input = input self.output = os.path.join(output, 'DeOldify') + self.render_factor = render_factor self.model = build_model() if weight_path is None: weight_path = get_path_from_url(DeOldify_weight_url, cur_path) @@ -93,7 +102,7 @@ class DeOldifyPredictor(): def run_single(self, img_path): ori_img = Image.open(img_path).convert('LA').convert('RGB') - img = self.norm(ori_img) + img = self.norm(ori_img, self.render_factor) x = paddle.to_tensor(img[np.newaxis, ...]) out = self.model(x) @@ -139,7 +148,7 @@ class DeOldifyPredictor(): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): - ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] vid_name = vid_path.split('/')[-1].split('.')[0] out_full_path = os.path.join(outpath, 'frames_input') @@ -158,23 +167,24 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ''.join(cmd) - print(cmd) + if os.system(cmd) == 0: - print('Video: {} done'.format(vid_name)) + pass else: - print('Video: {} error'.format(vid_name)) - print('') + print('ffmpeg process video: {} error'.format(vid_name)) + sys.stdout.flush() return out_full_path if __name__ == '__main__': - paddle.enable_imperative() + paddle.disable_static() args = parser.parse_args() predictor = DeOldifyPredictor(args.input, args.output, - weight_path=args.weight_path) + weight_path=args.weight_path, + render_factor=args.render_factor) frames_path, temp_video_path = predictor.run() print('output video path:', temp_video_path) diff --git a/applications/EDVR/data.py b/applications/EDVR/data.py index b05841522c9cca7e95fb7b2f5ef29eeafb9c9d70..ece62cf9fd8bfbee3640cc0bbfbc7567c15dae72 100644 --- a/applications/EDVR/data.py +++ b/applications/EDVR/data.py @@ -2,19 +2,20 @@ import cv2 import numpy as np + def read_img(path, size=None, is_gt=False): """read image by cv2 return: Numpy float32, HWC, BGR, [0,1]""" - # print('debug:', path) img = cv2.imread(path, cv2.IMREAD_UNCHANGED) - + img = img.astype(np.float32) / 255. if img.ndim == 2: img = np.expand_dims(img, axis=2) - + if img.shape[2] > 3: - img = img[:, :, :3] - return img + img = img[:, :, :3] + return img + def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'): """Generate an index list for reading N frames from a sequence of images @@ -62,7 +63,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'): else: add_idx = i return_l.append(add_idx) - # name_b = '{:08d}'.format(crt_i) + # name_b = '{:08d}'.format(crt_i) return return_l @@ -70,7 +71,6 @@ class EDVRDataset: def __init__(self, frame_paths): self.frames = frame_paths - def __getitem__(self, index): indexs = get_test_neighbor_frames(index, 5, len(self.frames)) frame_list = [] @@ -79,7 +79,6 @@ class EDVRDataset: frame_list.append(img) img_LQs = np.stack(frame_list, axis=0) - print('img:', img_LQs.shape) # BGR to RGB, HWC to CHW, numpy to tensor img_LQs = img_LQs[:, :, :, [2, 1, 0]] img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') @@ -87,4 +86,4 @@ class EDVRDataset: return img_LQs, self.frames[index] def __len__(self): - return len(self.frames) \ No newline at end of file + return len(self.frames) diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py index 4b888eaebecb3d0bf30ed634f9fb9eaac57b9ff7..11ab8928e877b36ef236fd73e15bf0ef381ded39 100644 --- a/applications/EDVR/predict.py +++ b/applications/EDVR/predict.py @@ -27,6 +27,7 @@ import numpy as np import paddle.fluid as fluid import cv2 +from tqdm import tqdm from data import EDVRDataset from paddle.utils.download import get_path_from_url @@ -52,7 +53,6 @@ def parse_args(): def get_img(pred): - print('pred shape', pred.shape) pred = pred.squeeze() pred = np.clip(pred, a_min=0., a_max=1.0) pred = pred * 255 @@ -72,7 +72,7 @@ def save_img(img, framename): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): - ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] vid_name = vid_path.split('/')[-1].split('.')[0] out_full_path = os.path.join(outpath, 'frames_input') @@ -91,30 +91,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ''.join(cmd) - print(cmd) + if os.system(cmd) == 0: - print('Video: {} done'.format(vid_name)) + pass else: - print('Video: {} error'.format(vid_name)) - print('') + print('ffmpeg process video: {} error'.format(vid_name)) + sys.stdout.flush() return out_full_path def frames_to_video_ffmpeg(framepath, videopath, r): - ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] cmd = ffmpeg + [ ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath ] cmd = ''.join(cmd) - print(cmd) if os.system(cmd) == 0: - print('Video: {} done'.format(videopath)) + pass else: - print('Video: {} error'.format(videopath)) - print('') + print('ffmpeg process video: {} error'.format(videopath)) + sys.stdout.flush() @@ -164,7 +163,7 @@ class EDVRPredictor: periods = [] cur_time = time.time() - for infer_iter, data in enumerate(dataset): + for infer_iter, data in enumerate(tqdm(dataset)): data_feed_in = [data[0]] infer_outs = self.exe.run( @@ -185,7 +184,7 @@ class EDVRPredictor: period = cur_time - prev_time periods.append(period) - print('Processed {} samples'.format(infer_iter + 1)) + # print('Processed {} samples'.format(infer_iter + 1)) frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') vid_out_path = os.path.join(self.output, '{}_edvr_out.mp4'.format(base_name)) diff --git a/applications/RealSR/predict.py b/applications/RealSR/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a05bf788e0f11ac73119b0f31c49a3615764c878 --- /dev/null +++ b/applications/RealSR/predict.py @@ -0,0 +1,150 @@ +import os +import sys + +cur_path = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(cur_path) + +import cv2 +import glob +import argparse +import numpy as np +import paddle +import pickle + +from PIL import Image +from tqdm import tqdm +from sr_model import RRDBNet +from paddle.utils.download import get_path_from_url + +parser = argparse.ArgumentParser(description='RealSR') +parser.add_argument('--input', type=str, default='none', help='Input video') +parser.add_argument('--output', type=str, default='output', help='output dir') +parser.add_argument('--weight_path', + type=str, + default=None, + help='Path to the reference image directory') + +RealSR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams' + + +def frames_to_video_ffmpeg(framepath, videopath, r): + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] + cmd = ffmpeg + [ + ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', + ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath + ] + cmd = ''.join(cmd) + + if os.system(cmd) == 0: + pass + else: + print('ffmpeg process video: {} error'.format(videopath)) + + sys.stdout.flush() + + +class RealSRPredictor(): + def __init__(self, input, output, batch_size=1, weight_path=None): + self.input = input + self.output = os.path.join(output, 'RealSR') + self.model = RRDBNet(3, 3, 64, 23) + if weight_path is None: + weight_path = get_path_from_url(RealSR_weight_url, cur_path) + + state_dict, _ = paddle.load(weight_path) + self.model.load_dict(state_dict) + self.model.eval() + + def norm(self, img): + img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0 + return img.astype('float32') + + def denorm(self, img): + img = img.transpose((1, 2, 0)) + return (img * 255).clip(0, 255).astype('uint8') + + def run_single(self, img_path): + ori_img = Image.open(img_path).convert('RGB') + img = self.norm(ori_img) + x = paddle.to_tensor(img[np.newaxis, ...]) + out = self.model(x) + + pred_img = self.denorm(out.numpy()[0]) + pred_img = Image.fromarray(pred_img) + return pred_img + + def run(self): + vid = self.input + base_name = os.path.basename(vid).split('.')[0] + output_path = os.path.join(self.output, base_name) + pred_frame_path = os.path.join(output_path, 'frames_pred') + + if not os.path.exists(output_path): + os.makedirs(output_path) + + if not os.path.exists(pred_frame_path): + os.makedirs(pred_frame_path) + + cap = cv2.VideoCapture(vid) + fps = cap.get(cv2.CAP_PROP_FPS) + + out_path = dump_frames_ffmpeg(vid, output_path) + + frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) + + for frame in tqdm(frames): + pred_img = self.run_single(frame) + + frame_name = os.path.basename(frame) + pred_img.save(os.path.join(pred_frame_path, frame_name)) + + frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') + + vid_out_path = os.path.join(output_path, + '{}_realsr_out.mp4'.format(base_name)) + frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path, + str(int(fps))) + + return frame_pattern_combined, vid_out_path + + +def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): + ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error '] + vid_name = vid_path.split('/')[-1].split('.')[0] + out_full_path = os.path.join(outpath, 'frames_input') + + if not os.path.exists(out_full_path): + os.makedirs(out_full_path) + + # video file name + outformat = out_full_path + '/%08d.png' + + if ss is not None and t is not None and r is not None: + cmd = ffmpeg + [ + ' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ', + ' 0.1 ', ' -start_number ', ' 0 ', outformat + ] + else: + cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] + + cmd = ''.join(cmd) + + if os.system(cmd) == 0: + pass + else: + print('ffmpeg process video: {} error'.format(vid_name)) + + sys.stdout.flush() + return out_full_path + + +if __name__ == '__main__': + paddle.disable_static() + args = parser.parse_args() + + predictor = RealSRPredictor(args.input, + args.output, + weight_path=args.weight_path) + frames_path, temp_video_path = predictor.run() + + print('output video path:', temp_video_path) diff --git a/applications/RealSR/sr_model.py b/applications/RealSR/sr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a730bea00c3b512473785290bc27a0744cf7e0 --- /dev/null +++ b/applications/RealSR/sr_model.py @@ -0,0 +1,76 @@ +import functools +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class ResidualDenseBlock_5C(nn.Layer): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias_attr=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias_attr=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1))) + x3 = self.lrelu(self.conv3(paddle.concat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(paddle.concat((x, x1, x2, x3), 1))) + x5 = self.conv5(paddle.concat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Layer): + '''Residual in Residual Dense Block''' + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class RRDBNet(nn.Layer): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias_attr=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias_attr=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu( + self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu( + self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out diff --git a/applications/run.sh b/applications/run.sh index fe7e6553ef0274d69ec5ed26a8da8455358762ba..8dcc8192c0e6b6698b052ccb6cd4abfbd106f4a9 100644 --- a/applications/run.sh +++ b/applications/run.sh @@ -1,13 +1,9 @@ -cd DAIN/pwcnet/correlation_op -# 第一次需要执行 -# bash make.shap -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'` -export PYTHONPATH=$PYTHONPATH:`pwd` -cd - - +# 模型说明 +# 目前包含DAIN(插帧模型),DeOldify(上色模型),DeepRemaster(去噪与上色模型),EDVR(基于连续帧(视频)超分辨率模型),RealSR(基于图片的超分辨率模型) +# 参数说明 # input 输入视频的路径 # output 输出视频保存的路径 -# proccess_order 使用模型的顺序 +# proccess_order 要使用的模型及顺序 -python tools/main.py \ ---input input.mp4 --output output --proccess_order DAIN DeepRemaster DeOldify EDVR +python tools/video-enhance.py \ +--input input.mp4 --output output --proccess_order DeOldify RealSR diff --git a/applications/tools/video-enhance.py b/applications/tools/video-enhance.py index 9a04a062e3c341c08deac84f728a09950d73b753..04ece7689d33f37c111e5f5acf2c20969f83c2bd 100644 --- a/applications/tools/video-enhance.py +++ b/applications/tools/video-enhance.py @@ -7,53 +7,109 @@ import paddle from DAIN.predict import VideoFrameInterp from DeepRemaster.predict import DeepReasterPredictor from DeOldify.predict import DeOldifyPredictor +from RealSR.predict import RealSRPredictor from EDVR.predict import EDVRPredictor parser = argparse.ArgumentParser(description='Fix video') -parser.add_argument('--input', type=str, default=None, help='Input video') -parser.add_argument('--output', type=str, default='output', help='output dir') -parser.add_argument('--DAIN_weight', type=str, default=None, help='Path to model weight') -parser.add_argument('--DeepRemaster_weight', type=str, default=None, help='Path to model weight') -parser.add_argument('--DeOldify_weight', type=str, default=None, help='Path to model weight') -parser.add_argument('--EDVR_weight', type=str, default=None, help='Path to model weight') +parser.add_argument('--input', type=str, default=None, help='Input video') +parser.add_argument('--output', type=str, default='output', help='output dir') +parser.add_argument('--DAIN_weight', + type=str, + default=None, + help='Path to model weight') +parser.add_argument('--DeepRemaster_weight', + type=str, + default=None, + help='Path to model weight') +parser.add_argument('--DeOldify_weight', + type=str, + default=None, + help='Path to model weight') +parser.add_argument('--RealSR_weight', + type=str, + default=None, + help='Path to model weight') +parser.add_argument('--EDVR_weight', + type=str, + default=None, + help='Path to model weight') # DAIN args -parser.add_argument('--time_step', type=float, default=0.5, help='choose the time steps') +parser.add_argument('--time_step', + type=float, + default=0.5, + help='choose the time steps') # DeepRemaster args -parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory') -parser.add_argument('--colorization', action='store_true', default=False, help='Remaster with colorization') -parser.add_argument('--mindim', type=int, default=360, help='Length of minimum image edges') -#process order support model name:[DAIN, DeepRemaster, DeOldify, EDVR] -parser.add_argument('--proccess_order', type=str, default='none', nargs='+', help='Process order') - +parser.add_argument('--reference_dir', + type=str, + default=None, + help='Path to the reference image directory') +parser.add_argument('--colorization', + action='store_true', + default=False, + help='Remaster with colorization') +parser.add_argument('--mindim', + type=int, + default=360, + help='Length of minimum image edges') +# DeOldify args +parser.add_argument('--render_factor', + type=int, + default=32, + help='model inputsize=render_factor*16') +#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR] +parser.add_argument('--proccess_order', + type=str, + default='none', + nargs='+', + help='Process order') if __name__ == "__main__": args = parser.parse_args() - + orders = args.proccess_order temp_video_path = None for order in orders: + print('Model {} proccess start..'.format(order)) if temp_video_path is None: temp_video_path = args.input if order == 'DAIN': - predictor = VideoFrameInterp(args.time_step, args.DAIN_weight, - temp_video_path, output_path=args.output) + predictor = VideoFrameInterp(args.time_step, + args.DAIN_weight, + temp_video_path, + output_path=args.output) frames_path, temp_video_path = predictor.run() elif order == 'DeepRemaster': paddle.disable_static() - predictor = DeepReasterPredictor(temp_video_path, args.output, weight_path=args.DeepRemaster_weight, - colorization=args.colorization, reference_dir=args.reference_dir, mindim=args.mindim) + predictor = DeepReasterPredictor( + temp_video_path, + args.output, + weight_path=args.DeepRemaster_weight, + colorization=args.colorization, + reference_dir=args.reference_dir, + mindim=args.mindim) frames_path, temp_video_path = predictor.run() paddle.enable_static() - elif order == 'DeOldify': + elif order == 'DeOldify': paddle.disable_static() - predictor = DeOldifyPredictor(temp_video_path, args.output, weight_path=args.DeOldify_weight) + predictor = DeOldifyPredictor(temp_video_path, + args.output, + weight_path=args.DeOldify_weight) + frames_path, temp_video_path = predictor.run() + paddle.enable_static() + elif order == 'RealSR': + paddle.disable_static() + predictor = RealSRPredictor(temp_video_path, + args.output, + weight_path=args.RealSR_weight) frames_path, temp_video_path = predictor.run() paddle.enable_static() elif order == 'EDVR': - predictor = EDVRPredictor(temp_video_path, args.output, weight_path=args.EDVR_weight) + predictor = EDVRPredictor(temp_video_path, + args.output, + weight_path=args.EDVR_weight) frames_path, temp_video_path = predictor.run() - + print('Model {} output frames path:'.format(order), frames_path) print('Model {} output video path:'.format(order), temp_video_path) - + print('Model {} proccess done!'.format(order))