未验证 提交 1091e63e 编写于 作者: L LielinJiang 提交者: GitHub

Add recurrent vsr predictor (#507)

* add recurrent vsr predictor
* update docs
上级 8aafe376
......@@ -20,6 +20,9 @@ from ppgan.apps import DeepRemasterPredictor
from ppgan.apps import DeOldifyPredictor
from ppgan.apps import RealSRPredictor
from ppgan.apps import EDVRPredictor
from ppgan.apps import PPMSVSRPredictor, BasicVSRPredictor, \
BasiVSRPlusPlusPredictor, IconVSRPredictor, \
PPMSVSRLargePredictor
parser = argparse.ArgumentParser(description='Fix video')
parser.add_argument('--input', type=str, default=None, help='Input video')
......@@ -44,6 +47,26 @@ parser.add_argument('--EDVR_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--PPMSVSR_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--PPMSVSRLarge_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--BasicVSR_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--IconVSR_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--BasiVSRPlusPlus_weight',
type=str,
default=None,
help='Path to model weight')
# DAIN args
parser.add_argument('--time_step',
type=float,
......@@ -75,6 +98,11 @@ parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
#vsr input number frames
parser.add_argument('--num_frames',
type=int,
default=10,
help='num frames for recurrent vsr model')
#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
parser.add_argument('--process_order',
type=str,
......@@ -121,6 +149,33 @@ if __name__ == "__main__":
elif order == 'EDVR':
predictor = EDVRPredictor(args.output, weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'PPMSVSR':
predictor = PPMSVSRPredictor(args.output,
weight_path=args.PPMSVSR_weight,
num_frames=args.num_frames)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'PPMSVSRLarge':
predictor = PPMSVSRLargePredictor(
args.output,
weight_path=args.PPMSVSRLarge_weight,
num_frames=args.num_frames)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'BasicVSR':
predictor = BasicVSRPredictor(args.output,
weight_path=args.BasicVSR_weight,
num_frames=args.num_frames)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'IconVSR':
predictor = IconVSRPredictor(args.output,
weight_path=args.IconVSR_weight,
num_frames=args.num_frames)
frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'BasiVSRPlusPlus':
predictor = BasiVSRPlusPlusPredictor(
args.output,
weight_path=args.BasiVSRPlusPlus_weight,
num_frames=args.num_frames)
frames_path, temp_video_path = predictor.run(temp_video_path)
print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path)
......
......@@ -112,4 +112,65 @@ ppgan.apps.EDVRPredictor(output='output', weight_path=None)
#### Parameters
- `output (str, Optional)`: path of your output, default: `output`.
- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
\ No newline at end of file
- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
### Video super-resolution model -- BasicVSRPredictor & IconVSRPredictor
BasicVSR is a generic and efficient baseline for VSR. With minimal redesigns of existing components including optical flow and residual blocks, it outperforms existing state of the arts with high efficiency. BasicVSR adopts a typical bidirectional recurrent network. The upsampling module U contains multiple pixel-shuffle and convolutions. The red and blue colors represent the backward and forward propagations, respectively. The propagation branches contain only generic components. S, W, and R refer to the flow estimation module, spatial warping module, and residual blocks, respectively.
![](../../imgs/basicvsr_arch.jpg)
```
ppgan.apps.BasiVSRPredictor(output='output', weight_path=None, num_frames=10)
ppgan.apps.IconVSRPredictor(output='output', weight_path=None, num_frames=10)
```
#### Parameters
- `output (str, Optional)`: path of your output, default: `output`.
- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
- `num_frames (10, Optional)`: the number of video frames input at a time. Default: `10`.
### Video super-resolution model -- BasicVSRPlusPlusPredictor
BasicVSR++ consists of two effective modifications for improving propagation and alignment. The proposed second-order grid propagation and flow-guided deformable alignment allow BasicVSR++ to significantly outperform existing state of the arts with comparable runtime. BasicVSR++ won 3 champions and 1 runner-up in NTIRE 2021 Video Restoration and Enhancement Challenge.
![](../../imgs/basicvsr++_arch.jpg)
```
ppgan.apps.BasiVSRPlusPlusPredictor(output='output', weight_path=None, num_frames=10)
```
#### Parameters
- `output (str, Optional)`: path of your output, default: `output`.
- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
- `num_frames (10, Optional)`: the number of video frames input at a time. Default: `10`.
### Video super-resolution model -- BasicVSRPlusPlusPredictor
BasicVSR++ consists of two effective modifications for improving propagation and alignment. The proposed second-order grid propagation and flow-guided deformable alignment allow BasicVSR++ to significantly outperform existing state of the arts with comparable runtime. BasicVSR++ won 3 champions and 1 runner-up in NTIRE 2021 Video Restoration and Enhancement Challenge.
![](../../imgs/basicvsr++_arch.jpg)
```
ppgan.apps.BasiVSRPlusPlusPredictor(output='output', weight_path=None, num_frames=10)
```
#### Parameters
- `output (str, Optional)`: path of your output, default: `output`.
- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
- `num_frames (10, Optional)`: the number of video frames input at a time. Default: `10`.
### Video super-resolution model -- PPMSVSRPredictor
PP-MSVSR proposes local fusion module, auxiliary loss and re-align module to refine the enhanced result progressively.
![](../../imgs/msvsr_arch.jpg)
```
ppgan.apps.PPMSVSRPredictor(output='output', weight_path=None, num_frames=10)
ppgan.apps.PPMSVSRLargePredictor(output='output', weight_path=None, num_frames=10)
```
#### Parameters
- `output (str, Optional)`: path of your output, default: `output`.
- `weight_path (None, Optional)`: path of your model weight. If it is not set, the default weight will be downloaded from the cloud to the local. Default: `None`.
- `num_frames (10, Optional)`: the number of video frames input at a time. Default: `10`.
......@@ -31,3 +31,6 @@ from .wav2lip_predictor import Wav2LipPredictor
from .mpr_predictor import MPRPredictor
from .lapstyle_predictor import LapStylePredictor
from .photopen_predictor import PhotoPenPredictor
from .recurrent_vsr_predictor import (PPMSVSRPredictor, BasicVSRPredictor, \
BasiVSRPlusPlusPredictor, IconVSRPredictor, \
PPMSVSRLargePredictor)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import cv2
import time
import glob
import numpy as np
from tqdm import tqdm
import paddle
from paddle.io import Dataset, DataLoader
from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators import BasicVSRNet, IconVSR, BasicVSRPlusPlus, MSVSR
from .base_predictor import BasePredictor
from .edvr_predictor import get_img, read_img, save_img
BasicVSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams'
IconVSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams'
BasicVSR_PP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams'
PP_MSVSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/PP-MSVSR_reds_x4.pdparams'
PP_MSVSR_BD_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/PP-MSVSR_vimeo90k_x4.pdparams'
PP_MSVSR_L_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/PP-MSVSR-L_reds_x4.pdparams'
class RecurrentDataset(Dataset):
def __init__(self, frames_path, num_frames=30):
self.frames_path = frames_path
if num_frames is not None:
self.num_frames = num_frames
else:
self.num_frames = len(self.frames_path)
if len(frames_path) % self.num_frames == 0:
self.size = len(frames_path) // self.num_frames
else:
self.size = len(frames_path) // self.num_frames + 1
def __getitem__(self, index):
indexs = list(
range(index * self.num_frames, (index + 1) * self.num_frames))
frame_list = []
frames_path = []
for i in indexs:
if i >= len(self.frames_path):
break
frames_path.append(self.frames_path[i])
img = read_img(self.frames_path[i])
frame_list.append(img)
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
return img_LQs, frames_path
def __len__(self):
return self.size
class BasicVSRPredictor(BasePredictor):
def __init__(self, output='output', weight_path=None, num_frames=10):
self.input = input
self.name = 'BasiVSR'
self.num_frames = num_frames
self.output = os.path.join(output, self.name)
self.model = BasicVSRNet()
if weight_path is None:
weight_path = get_path_from_url(BasicVSR_WEIGHT_URL)
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()
def run(self, video_path):
vid = video_path
base_name = os.path.basename(vid).split('.')[0]
output_path = os.path.join(self.output, base_name)
pred_frame_path = os.path.join(output_path, 'frames_pred')
if not os.path.exists(output_path):
os.makedirs(output_path)
if not os.path.exists(pred_frame_path):
os.makedirs(pred_frame_path)
cap = cv2.VideoCapture(vid)
fps = cap.get(cv2.CAP_PROP_FPS)
out_path = video2frames(vid, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
test_dataset = RecurrentDataset(frames, num_frames=self.num_frames)
dataset = DataLoader(test_dataset, batch_size=1, num_workers=2)
periods = []
cur_time = time.time()
for infer_iter, data in enumerate(tqdm(dataset)):
data_feed_in = paddle.to_tensor(data[0])
with paddle.no_grad():
outs = self.model(data_feed_in)
if isinstance(outs, (list, tuple)):
outs = outs[-1]
outs = outs[0].numpy()
infer_result_list = [outs[i, :, :, :] for i in range(outs.shape[0])]
frames_path = data[1]
for i in range(len(infer_result_list)):
img_i = get_img(infer_result_list[i])
save_img(
img_i,
os.path.join(pred_frame_path,
os.path.basename(frames_path[i][0])))
prev_time = cur_time
cur_time = time.time()
period = cur_time - prev_time
periods.append(period)
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(
self.output, '{}_{}_out.mp4'.format(base_name, self.name))
frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
return frame_pattern_combined, vid_out_path
class IconVSRPredictor(BasicVSRPredictor):
def __init__(self, output='output', weight_path=None, num_frames=10):
self.input = input
self.name = 'IconVSR'
self.output = os.path.join(output, self.name)
self.num_frames = num_frames
self.model = IconVSR()
if weight_path is None:
weight_path = get_path_from_url(IconVSR_WEIGHT_URL)
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()
class BasiVSRPlusPlusPredictor(BasicVSRPredictor):
def __init__(self, output='output', weight_path=None, num_frames=10):
self.input = input
self.name = 'BasiVSR_PP'
self.output = os.path.join(output, self.name)
self.num_frames = num_frames
self.model = BasicVSRPlusPlus()
if weight_path is None:
weight_path = get_path_from_url(BasicVSR_PP_WEIGHT_URL)
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()
class PPMSVSRPredictor(BasicVSRPredictor):
def __init__(self, output='output', weight_path=None, num_frames=10):
self.input = input
self.name = 'PPMSVSR'
self.output = os.path.join(output, self.name)
self.num_frames = num_frames
self.model = MSVSR()
if weight_path is None:
weight_path = get_path_from_url(PP_MSVSR_WEIGHT_URL)
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()
class PPMSVSRLargePredictor(BasicVSRPredictor):
def __init__(self, output='output', weight_path=None, num_frames=10):
self.input = input
self.name = 'PPMSVSR-L'
self.output = os.path.join(output, self.name)
self.num_frames = num_frames
self.model = MSVSR(mid_channels=64,
num_init_blocks=5,
num_blocks=7,
num_reconstruction_blocks=5,
only_last=False,
use_tiny_spynet=False,
deform_groups=8,
aux_reconstruction_blocks=2)
if weight_path is None:
weight_path = get_path_from_url(PP_MSVSR_L_WEIGHT_URL)
self.model.set_dict(paddle.load(weight_path)['generator'])
self.model.eval()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册