diff --git a/applications/tools/video-enhance.py b/applications/tools/video-enhance.py
index 343388a3090761a15fa4256c2713c1c8cf357d1b..e7476a710eb17a8a8ad029288db1a1416f2285db 100644
--- a/applications/tools/video-enhance.py
+++ b/applications/tools/video-enhance.py
@@ -20,6 +20,9 @@ from ppgan.apps import DeepRemasterPredictor
 from ppgan.apps import DeOldifyPredictor
 from ppgan.apps import RealSRPredictor
 from ppgan.apps import EDVRPredictor
+from ppgan.apps import PPMSVSRPredictor, BasicVSRPredictor, \
+    BasiVSRPlusPlusPredictor, IconVSRPredictor, \
+    PPMSVSRLargePredictor
 
 parser = argparse.ArgumentParser(description='Fix video')
 parser.add_argument('--input', type=str, default=None, help='Input video')
@@ -44,6 +47,26 @@ parser.add_argument('--EDVR_weight',
                     type=str,
                     default=None,
                     help='Path to model weight')
+parser.add_argument('--PPMSVSR_weight',
+                    type=str,
+                    default=None,
+                    help='Path to model weight')
+parser.add_argument('--PPMSVSRLarge_weight',
+                    type=str,
+                    default=None,
+                    help='Path to model weight')
+parser.add_argument('--BasicVSR_weight',
+                    type=str,
+                    default=None,
+                    help='Path to model weight')
+parser.add_argument('--IconVSR_weight',
+                    type=str,
+                    default=None,
+                    help='Path to model weight')
+parser.add_argument('--BasiVSRPlusPlus_weight',
+                    type=str,
+                    default=None,
+                    help='Path to model weight')
 # DAIN args
 parser.add_argument('--time_step',
                     type=float,
@@ -75,6 +98,11 @@ parser.add_argument('--render_factor',
                     type=int,
                     default=32,
                     help='model inputsize=render_factor*16')
+# number of input frames per chunk for the recurrent VSR models
+parser.add_argument('--num_frames',
+                    type=int,
+                    default=10,
+                    help='number of frames fed to the recurrent VSR model at a time')
 #process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
 parser.add_argument('--process_order',
                     type=str,
@@ -121,6 +149,33 @@ if __name__ == "__main__":
         elif order == 'EDVR':
             predictor = EDVRPredictor(args.output, weight_path=args.EDVR_weight)
             frames_path, temp_video_path = predictor.run(temp_video_path)
+        elif order == 'PPMSVSR':
+            predictor = PPMSVSRPredictor(args.output,
+                                         weight_path=args.PPMSVSR_weight,
+                                         num_frames=args.num_frames)
+            frames_path, temp_video_path = predictor.run(temp_video_path)
+        elif order == 'PPMSVSRLarge':
+            predictor = PPMSVSRLargePredictor(
+                args.output,
+                weight_path=args.PPMSVSRLarge_weight,
+                num_frames=args.num_frames)
+            frames_path, temp_video_path = predictor.run(temp_video_path)
+        elif order == 'BasicVSR':
+            predictor = BasicVSRPredictor(args.output,
+                                          weight_path=args.BasicVSR_weight,
+                                          num_frames=args.num_frames)
+            frames_path, temp_video_path = predictor.run(temp_video_path)
+        elif order == 'IconVSR':
+            predictor = IconVSRPredictor(args.output,
+                                         weight_path=args.IconVSR_weight,
+                                         num_frames=args.num_frames)
+            frames_path, temp_video_path = predictor.run(temp_video_path)
+        elif order == 'BasiVSRPlusPlus':
+            predictor = BasiVSRPlusPlusPredictor(
+                args.output,
+                weight_path=args.BasiVSRPlusPlus_weight,
+                num_frames=args.num_frames)
+            frames_path, temp_video_path = predictor.run(temp_video_path)
 
         print('Model {} output frames path:'.format(order), frames_path)
         print('Model {} output video path:'.format(order), temp_video_path)
diff --git a/docs/en_US/tutorials/video_restore.md b/docs/en_US/tutorials/video_restore.md
index 2f8f4fce79d9961556a45189843cf4b3d2f27ca5..8365c8fd000e872d78a30648f890e8694ea50903 100644
--- a/docs/en_US/tutorials/video_restore.md
+++ b/docs/en_US/tutorials/video_restore.md
@@ -112,4 +112,59 @@ ppgan.apps.EDVRPredictor(output='output', weight_path=None)
 #### Parameters
 
 - `output (str, Optional)`: path of your output, default: `output`.
-- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
\ No newline at end of file
+- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
+
+### Video super-resolution model -- BasicVSRPredictor & IconVSRPredictor
+BasicVSR is a generic and efficient baseline for VSR. With minimal redesigns of existing components, including optical flow and residual blocks, it outperforms existing state-of-the-art methods with high efficiency. BasicVSR adopts a typical bidirectional recurrent network. The upsampling module U contains multiple pixel-shuffle and convolution layers. The red and blue colors represent the backward and forward propagations, respectively. The propagation branches contain only generic components. S, W, and R refer to the flow estimation module, spatial warping module, and residual blocks, respectively.
+
+![](../../imgs/basicvsr_arch.jpg)
+
+```
+ppgan.apps.BasicVSRPredictor(output='output', weight_path=None, num_frames=10)
+ppgan.apps.IconVSRPredictor(output='output', weight_path=None, num_frames=10)
+```
+#### Parameters
+
+- `output (str, Optional)`: path of your output, default: `output`.
+- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
+- `num_frames (10, Optional)`: the number of video frames fed to the model at a time. Default: `10`.
+
+
+### Video super-resolution model -- BasicVSRPlusPlusPredictor
+BasicVSR++ consists of two effective modifications for improving propagation and alignment. The proposed second-order grid propagation and flow-guided deformable alignment allow BasicVSR++ to significantly outperform existing state-of-the-art methods with comparable runtime. BasicVSR++ won three first places and one runner-up in the NTIRE 2021 Video Restoration and Enhancement Challenge.
+
+![](../../imgs/basicvsr++_arch.jpg)
+
+```
+ppgan.apps.BasiVSRPlusPlusPredictor(output='output', weight_path=None, num_frames=10)
+```
+#### Parameters
+
+- `output (str, Optional)`: path of your output, default: `output`.
+- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
+- `num_frames (10, Optional)`: the number of video frames fed to the model at a time. Default: `10`.
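+
+All of these predictors share the same interface: construct one, then call `run()` with the path of a video file. `run()` returns the filename pattern of the predicted frames and the path of the restored output video. A minimal sketch (the input file `input.mp4` is just a placeholder):
+
+```
+from ppgan.apps import BasicVSRPredictor
+
+predictor = BasicVSRPredictor(output='output', weight_path=None, num_frames=10)
+frames_pattern, video_path = predictor.run('input.mp4')
+```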
+
+
+### Video super-resolution model -- PPMSVSRPredictor
+PP-MSVSR proposes a local fusion module, an auxiliary loss, and a re-align module to refine the enhanced result progressively.
+
+![](../../imgs/msvsr_arch.jpg)
+
+```
+ppgan.apps.PPMSVSRPredictor(output='output', weight_path=None, num_frames=10)
+ppgan.apps.PPMSVSRLargePredictor(output='output', weight_path=None, num_frames=10)
+```
+#### Parameters
+
+- `output (str, Optional)`: path of your output, default: `output`.
+- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
+- `num_frames (10, Optional)`: the number of video frames fed to the model at a time. Default: `10`.
diff --git a/ppgan/apps/__init__.py b/ppgan/apps/__init__.py
index 65d731394d6dc50e8709e1f1f8ac369ed8290622..32ee395c25139616045542c58fd3a4820a6fc9b1 100644
--- a/ppgan/apps/__init__.py
+++ b/ppgan/apps/__init__.py
@@ -31,3 +31,6 @@ from .wav2lip_predictor import Wav2LipPredictor
 from .mpr_predictor import MPRPredictor
 from .lapstyle_predictor import LapStylePredictor
 from .photopen_predictor import PhotoPenPredictor
+from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
+                                      BasiVSRPlusPlusPredictor, IconVSRPredictor, \
+                                      PPMSVSRLargePredictor)
diff --git a/ppgan/apps/recurrent_vsr_predictor.py b/ppgan/apps/recurrent_vsr_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd3e5082989e42b88bb33c3b6506c70d32d74ad
--- /dev/null
+++ b/ppgan/apps/recurrent_vsr_predictor.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
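+
+# Recurrent VSR predictors wrapping the BasicVSR, IconVSR, BasicVSR++ and
+# PP-MSVSR generators; used by applications/tools/video-enhance.py.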
+
+import os
+import cv2
+import time
+import glob
+import numpy as np
+from tqdm import tqdm
+
+import paddle
+from paddle.io import Dataset, DataLoader
+
+from ppgan.utils.download import get_path_from_url
+from ppgan.utils.video import frames2video, video2frames
+from ppgan.models.generators import BasicVSRNet, IconVSR, BasicVSRPlusPlus, MSVSR
+from .base_predictor import BasePredictor
+from .edvr_predictor import get_img, read_img, save_img
+
+BasicVSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams'
+IconVSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams'
+BasicVSR_PP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams'
+PP_MSVSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams'
+PP_MSVSR_BD_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/PP-MSVSR_vimeo90k_x4.pdparams'
+PP_MSVSR_L_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/PP-MSVSR-L_reds_x4.pdparams'
+
+
+class RecurrentDataset(Dataset):
+    def __init__(self, frames_path, num_frames=30):
+        self.frames_path = frames_path
+
+        if num_frames is not None:
+            self.num_frames = num_frames
+        else:
+            self.num_frames = len(self.frames_path)
+
+        if len(frames_path) % self.num_frames == 0:
+            self.size = len(frames_path) // self.num_frames
+        else:
+            self.size = len(frames_path) // self.num_frames + 1
+
+    def __getitem__(self, index):
+        indexs = list(
+            range(index * self.num_frames, (index + 1) * self.num_frames))
+        frame_list = []
+        frames_path = []
+        for i in indexs:
+            if i >= len(self.frames_path):
+                break
+
+            frames_path.append(self.frames_path[i])
+            img = read_img(self.frames_path[i])
+            frame_list.append(img)
+
+        img_LQs = np.stack(frame_list, axis=0)
+        # BGR to RGB, HWC to CHW, numpy to tensor
+        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
+        img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
+
+        return img_LQs, frames_path
+
+    def __len__(self):
+        return self.size
+
+
+class BasicVSRPredictor(BasePredictor):
+    def __init__(self, output='output', weight_path=None, num_frames=10):
+        self.input = input
+        self.name = 'BasiVSR'
+        self.num_frames = num_frames
+        self.output = os.path.join(output, self.name)
+        self.model = BasicVSRNet()
+        if weight_path is None:
+            weight_path = get_path_from_url(BasicVSR_WEIGHT_URL)
+        self.model.set_dict(paddle.load(weight_path)['generator'])
+        self.model.eval()
+
+    def run(self, video_path):
+        vid = video_path
+        base_name = os.path.basename(vid).split('.')[0]
+        output_path = os.path.join(self.output, base_name)
+        pred_frame_path = os.path.join(output_path, 'frames_pred')
+
+        if not os.path.exists(output_path):
+            os.makedirs(output_path)
+
+        if not os.path.exists(pred_frame_path):
+            os.makedirs(pred_frame_path)
+
+        cap = cv2.VideoCapture(vid)
+        fps = cap.get(cv2.CAP_PROP_FPS)
+
+        out_path = video2frames(vid, output_path)
+
+        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
+
+        test_dataset = RecurrentDataset(frames, num_frames=self.num_frames)
+        dataset = DataLoader(test_dataset, batch_size=1, num_workers=2)
+
+        periods = []
+        cur_time = time.time()
+        for infer_iter, data in enumerate(tqdm(dataset)):
+            data_feed_in = paddle.to_tensor(data[0])
+            with paddle.no_grad():
+                outs = self.model(data_feed_in)
+
+            if isinstance(outs, (list, tuple)):
+                outs = outs[-1]
+
+            outs = outs[0].numpy()
+
+            infer_result_list = [outs[i, :, :, :] for i in range(outs.shape[0])]
+
+            frames_path = data[1]
+
+            for i in range(len(infer_result_list)):
+                img_i = get_img(infer_result_list[i])
+                save_img(
+                    img_i,
+                    os.path.join(pred_frame_path,
+                                 os.path.basename(frames_path[i][0])))
+
+            prev_time = cur_time
+            cur_time = time.time()
+            period = cur_time - prev_time
+            periods.append(period)
+
+        frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
+        vid_out_path = os.path.join(
+            self.output, '{}_{}_out.mp4'.format(base_name, self.name))
+        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
+
+        return frame_pattern_combined, vid_out_path
+
+
+class IconVSRPredictor(BasicVSRPredictor):
+    def __init__(self, output='output', weight_path=None, num_frames=10):
+        self.input = input
+        self.name = 'IconVSR'
+        self.output = os.path.join(output, self.name)
+        self.num_frames = num_frames
+        self.model = IconVSR()
+        if weight_path is None:
+            weight_path = get_path_from_url(IconVSR_WEIGHT_URL)
+        self.model.set_dict(paddle.load(weight_path)['generator'])
+        self.model.eval()
+
+
+class BasiVSRPlusPlusPredictor(BasicVSRPredictor):
+    def __init__(self, output='output', weight_path=None, num_frames=10):
+        self.input = input
+        self.name = 'BasiVSR_PP'
+        self.output = os.path.join(output, self.name)
+        self.num_frames = num_frames
+        self.model = BasicVSRPlusPlus()
+        if weight_path is None:
+            weight_path = get_path_from_url(BasicVSR_PP_WEIGHT_URL)
+        self.model.set_dict(paddle.load(weight_path)['generator'])
+        self.model.eval()
+
+
+class PPMSVSRPredictor(BasicVSRPredictor):
+    def __init__(self, output='output', weight_path=None, num_frames=10):
+        self.input = input
+        self.name = 'PPMSVSR'
+        self.output = os.path.join(output, self.name)
+        self.num_frames = num_frames
+        self.model = MSVSR()
+        if weight_path is None:
+            weight_path = get_path_from_url(PP_MSVSR_WEIGHT_URL)
+        self.model.set_dict(paddle.load(weight_path)['generator'])
+        self.model.eval()
+
+
+class PPMSVSRLargePredictor(BasicVSRPredictor):
+    def __init__(self, output='output', weight_path=None, num_frames=10):
+        self.input = input
+        self.name = 'PPMSVSR-L'
+        self.output = os.path.join(output, self.name)
+        self.num_frames = num_frames
+        self.model = MSVSR(mid_channels=64,
+                           num_init_blocks=5,
+                           num_blocks=7,
+                           num_reconstruction_blocks=5,
+                           only_last=False,
+                           use_tiny_spynet=False,
+                           deform_groups=8,
+                           aux_reconstruction_blocks=2)
+        if weight_path is None:
+            weight_path = get_path_from_url(PP_MSVSR_L_WEIGHT_URL)
+        self.model.set_dict(paddle.load(weight_path)['generator'])
+        self.model.eval()