提交 9ed134f8 编写于 作者: L LielinJiang

Merge branch 'master' of https://github.com/PaddlePaddle/PaddleGAN into sr

......@@ -8,6 +8,7 @@ import time
import glob
import numpy as np
from imageio import imread, imsave
from tqdm import tqdm
import cv2
import paddle.fluid as fluid
......@@ -175,8 +176,7 @@ class VideoFrameInterp(object):
if not os.path.exists(os.path.join(frame_path_combined, vidname)):
os.makedirs(os.path.join(frame_path_combined, vidname))
for i in range(frame_num - 1):
print(frames[i])
for i in tqdm(range(frame_num - 1)):
first = frames[i]
second = frames[i + 1]
......@@ -208,12 +208,10 @@ class VideoFrameInterp(object):
assert (X0.shape[1] == X1.shape[1])
assert (X0.shape[2] == X1.shape[2])
print("size before padding ", X0.shape)
X0 = np.pad(X0, ((0,0), (padding_top, padding_bottom), \
(padding_left, padding_right)), mode='edge')
X1 = np.pad(X1, ((0,0), (padding_top, padding_bottom), \
(padding_left, padding_right)), mode='edge')
print("size after padding ", X0.shape)
X0 = np.expand_dims(X0, axis=0)
X1 = np.expand_dims(X1, axis=0)
......@@ -233,8 +231,6 @@ class VideoFrameInterp(object):
proc_timer.update(time.time() - proc_end)
tot_timer.update(time.time() - end)
end = time.time()
print("*********** current image process time \t " +
str(time.time() - proc_end) + "s *********")
y_ = [
np.transpose(
......
......@@ -17,8 +17,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Conv2DTranspose
from .correlation_op.correlation import correlation
from paddle.fluid.contrib import correlation
__all__ = ['pwc_dc_net']
......
......@@ -22,7 +22,7 @@ class AverageMeter(object):
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, vid_name)
......@@ -55,30 +55,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
......@@ -99,7 +98,8 @@ def combine_frames(input, interpolated, combined, num_frames):
for k in range(num_frames):
src = frames2[i * num_frames + k]
dst = os.path.join(
combined, '{:08d}.png'.format(i * (num_frames + 1) + k + 1))
combined,
'{:08d}.png'.format(i * (num_frames + 1) + k + 1))
shutil.copy2(src, dst)
except Exception as e:
print(e)
......
......@@ -3,14 +3,16 @@ import numpy as np
import paddle
import paddle.nn as nn
def is_listy(x):
return isinstance(x, (tuple,list))
return isinstance(x, (tuple, list))
class Hook():
"Create a hook on `m` with `hook_func`."
def __init__(self, m, hook_func, is_forward=True, detach=True):
self.hook_func,self.detach,self.stored = hook_func,detach,None
self.hook_func, self.detach, self.stored = hook_func, detach, None
f = m.register_forward_post_hook if is_forward else m.register_backward_hook
self.hook = f(self.hook_fn)
self.removed = False
......@@ -18,64 +20,90 @@ class Hook():
def hook_fn(self, module, input, output):
"Applies `hook_func` to `module`, `input`, `output`."
if self.detach:
input = (o.detach() for o in input ) if is_listy(input ) else input.detach()
output = (o.detach() for o in output) if is_listy(output) else output.detach()
input = (o.detach()
for o in input) if is_listy(input) else input.detach()
output = (o.detach()
for o in output) if is_listy(output) else output.detach()
self.stored = self.hook_func(module, input, output)
def remove(self):
"Remove the hook from the model."
if not self.removed:
self.hook.remove()
self.removed=True
self.removed = True
def __enter__(self, *args):
return self
def __exit__(self, *args):
self.remove()
def __enter__(self, *args): return self
def __exit__(self, *args): self.remove()
class Hooks():
"Create several hooks on the modules in `ms` with `hook_func`."
def __init__(self, ms, hook_func, is_forward=True, detach=True):
self.hooks = []
try:
for m in ms:
self.hooks.append(Hook(m, hook_func, is_forward, detach))
except Exception as e:
print(e)
pass
def __getitem__(self, i: int) -> Hook:
return self.hooks[i]
def __len__(self) -> int:
return len(self.hooks)
def __iter__(self):
return iter(self.hooks)
def __getitem__(self,i:int)->Hook: return self.hooks[i]
def __len__(self)->int: return len(self.hooks)
def __iter__(self): return iter(self.hooks)
@property
def stored(self): return [o.stored for o in self]
def stored(self):
return [o.stored for o in self]
def remove(self):
"Remove the hooks from the model."
for h in self.hooks: h.remove()
for h in self.hooks:
h.remove()
def __enter__(self, *args): return self
def __exit__ (self, *args): self.remove()
def __enter__(self, *args):
return self
def _hook_inner(m,i,o): return o if isinstance(o, paddle.framework.Variable) else o if is_listy(o) else list(o)
def __exit__(self, *args):
self.remove()
def hook_output (module, detach=True, grad=False):
def _hook_inner(m, i, o):
return o if isinstance(
o, paddle.framework.Variable) else o if is_listy(o) else list(o)
def hook_output(module, detach=True, grad=False):
"Return a `Hook` that stores activations of `module` in `self.stored`"
return Hook(module, _hook_inner, detach=detach, is_forward=not grad)
def hook_outputs(modules, detach=True, grad=False):
"Return `Hooks` that store activations of all `modules` in `self.stored`"
return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
def model_sizes(m, size=(64,64)):
def model_sizes(m, size=(64, 64)):
"Pass a dummy input through the model `m` to get the various sizes of activations."
with hook_outputs(m) as hooks:
x = dummy_eval(m, size)
return [o.stored.shape for o in hooks]
def dummy_eval(m, size=(64,64)):
def dummy_eval(m, size=(64, 64)):
"Pass a `dummy_batch` in evaluation mode in `m` with `size`."
m.eval()
return m(dummy_batch(size))
def dummy_batch(size=(64,64), ch_in=3):
def dummy_batch(size=(64, 64), ch_in=3):
"Create a dummy batch to go through `m` with `size`."
arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1
return paddle.to_tensor(arr)
此差异已折叠。
......@@ -20,35 +20,44 @@ from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
parser.add_argument('--weight_path',
type=str,
default='none',
default=None,
help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
class DeOldifyPredictor():
def __init__(self, input, output, batch_size=1, weight_path=None):
def __init__(self,
input,
output,
batch_size=1,
weight_path=None,
render_factor=32):
self.input = input
self.output = os.path.join(output, 'DeOldify')
self.render_factor = render_factor
self.model = build_model()
if weight_path is None:
weight_path = get_path_from_url(DeOldify_weight_url, cur_path)
......@@ -93,7 +102,7 @@ class DeOldifyPredictor():
def run_single(self, img_path):
ori_img = Image.open(img_path).convert('LA').convert('RGB')
img = self.norm(ori_img)
img = self.norm(ori_img, self.render_factor)
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
......@@ -139,7 +148,7 @@ class DeOldifyPredictor():
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, 'frames_input')
......@@ -158,23 +167,24 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
if __name__ == '__main__':
paddle.enable_imperative()
paddle.disable_static()
args = parser.parse_args()
predictor = DeOldifyPredictor(args.input,
args.output,
weight_path=args.weight_path)
weight_path=args.weight_path,
render_factor=args.render_factor)
frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path)
......@@ -2,19 +2,20 @@ import cv2
import numpy as np
def read_img(path, size=None, is_gt=False):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]"""
# print('debug:', path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
if img.shape[2] > 3:
img = img[:, :, :3]
return img
img = img[:, :, :3]
return img
def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
"""Generate an index list for reading N frames from a sequence of images
......@@ -62,7 +63,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
else:
add_idx = i
return_l.append(add_idx)
# name_b = '{:08d}'.format(crt_i)
# name_b = '{:08d}'.format(crt_i)
return return_l
......@@ -70,7 +71,6 @@ class EDVRDataset:
def __init__(self, frame_paths):
self.frames = frame_paths
def __getitem__(self, index):
indexs = get_test_neighbor_frames(index, 5, len(self.frames))
frame_list = []
......@@ -79,7 +79,6 @@ class EDVRDataset:
frame_list.append(img)
img_LQs = np.stack(frame_list, axis=0)
print('img:', img_LQs.shape)
# BGR to RGB, HWC to CHW, numpy to tensor
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
......@@ -87,4 +86,4 @@ class EDVRDataset:
return img_LQs, self.frames[index]
def __len__(self):
return len(self.frames)
\ No newline at end of file
return len(self.frames)
......@@ -27,6 +27,7 @@ import numpy as np
import paddle.fluid as fluid
import cv2
from tqdm import tqdm
from data import EDVRDataset
from paddle.utils.download import get_path_from_url
......@@ -52,7 +53,6 @@ def parse_args():
def get_img(pred):
print('pred shape', pred.shape)
pred = pred.squeeze()
pred = np.clip(pred, a_min=0., a_max=1.0)
pred = pred * 255
......@@ -72,7 +72,7 @@ def save_img(img, framename):
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, 'frames_input')
......@@ -91,30 +91,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error ']
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
......@@ -164,7 +163,7 @@ class EDVRPredictor:
periods = []
cur_time = time.time()
for infer_iter, data in enumerate(dataset):
for infer_iter, data in enumerate(tqdm(dataset)):
data_feed_in = [data[0]]
infer_outs = self.exe.run(
......@@ -185,7 +184,7 @@ class EDVRPredictor:
period = cur_time - prev_time
periods.append(period)
print('Processed {} samples'.format(infer_iter + 1))
# print('Processed {} samples'.format(infer_iter + 1))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output,
'{}_edvr_out.mp4'.format(base_name))
......
import os
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(cur_path)
import cv2
import glob
import argparse
import numpy as np
import paddle
import pickle
from PIL import Image
from tqdm import tqdm
from sr_model import RRDBNet
from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='RealSR')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--weight_path',
type=str,
default=None,
help='Path to the reference image directory')
RealSR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
if os.system(cmd) == 0:
pass
else:
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
class RealSRPredictor():
def __init__(self, input, output, batch_size=1, weight_path=None):
self.input = input
self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23)
if weight_path is None:
weight_path = get_path_from_url(RealSR_weight_url, cur_path)
state_dict, _ = paddle.load(weight_path)
self.model.load_dict(state_dict)
self.model.eval()
def norm(self, img):
img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0
return img.astype('float32')
def denorm(self, img):
img = img.transpose((1, 2, 0))
return (img * 255).clip(0, 255).astype('uint8')
def run_single(self, img_path):
ori_img = Image.open(img_path).convert('RGB')
img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
pred_img = self.denorm(out.numpy()[0])
pred_img = Image.fromarray(pred_img)
return pred_img
def run(self):
vid = self.input
base_name = os.path.basename(vid).split('.')[0]
output_path = os.path.join(self.output, base_name)
pred_frame_path = os.path.join(output_path, 'frames_pred')
if not os.path.exists(output_path):
os.makedirs(output_path)
if not os.path.exists(pred_frame_path):
os.makedirs(pred_frame_path)
cap = cv2.VideoCapture(vid)
fps = cap.get(cv2.CAP_PROP_FPS)
out_path = dump_frames_ffmpeg(vid, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
for frame in tqdm(frames):
pred_img = self.run_single(frame)
frame_name = os.path.basename(frame)
pred_img.save(os.path.join(pred_frame_path, frame_name))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(output_path,
'{}_realsr_out.mp4'.format(base_name))
frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
str(int(fps)))
return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, 'frames_input')
if not os.path.exists(out_full_path):
os.makedirs(out_full_path)
# video file name
outformat = out_full_path + '/%08d.png'
if ss is not None and t is not None and r is not None:
cmd = ffmpeg + [
' -ss ', ss, ' -t ', t, ' -i ', vid_path, ' -r ', r, ' -qscale:v ',
' 0.1 ', ' -start_number ', ' 0 ', outformat
]
else:
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
if os.system(cmd) == 0:
pass
else:
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
if __name__ == '__main__':
paddle.disable_static()
args = parser.parse_args()
predictor = RealSRPredictor(args.input,
args.output,
weight_path=args.weight_path)
frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path)
import functools
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class ResidualDenseBlock_5C(nn.Layer):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias_attr=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias_attr=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1)))
x3 = self.lrelu(self.conv3(paddle.concat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(paddle.concat((x, x1, x2, x3), 1)))
x5 = self.conv5(paddle.concat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Layer):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class RRDBNet(nn.Layer):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias_attr=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias_attr=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(
self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
cd DAIN/pwcnet/correlation_op
# 第一次需要执行
# bash make.shap
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
export PYTHONPATH=$PYTHONPATH:`pwd`
cd -
# 模型说明
# 目前包含DAIN(插帧模型),DeOldify(上色模型),DeepRemaster(去噪与上色模型),EDVR(基于连续帧(视频)超分辨率模型),RealSR(基于图片的超分辨率模型)
# 参数说明
# input 输入视频的路径
# output 输出视频保存的路径
# proccess_order 使用模型的顺序
# proccess_order 要使用的模型及顺序
python tools/main.py \
--input input.mp4 --output output --proccess_order DAIN DeepRemaster DeOldify EDVR
python tools/video-enhance.py \
--input input.mp4 --output output --proccess_order DeOldify RealSR
......@@ -7,53 +7,109 @@ import paddle
from DAIN.predict import VideoFrameInterp
from DeepRemaster.predict import DeepReasterPredictor
from DeOldify.predict import DeOldifyPredictor
from RealSR.predict import RealSRPredictor
from EDVR.predict import EDVRPredictor
parser = argparse.ArgumentParser(description='Fix video')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--DAIN_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--DeepRemaster_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--DeOldify_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--EDVR_weight', type=str, default=None, help='Path to model weight')
parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--DAIN_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--DeepRemaster_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--DeOldify_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--RealSR_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--EDVR_weight',
type=str,
default=None,
help='Path to model weight')
# DAIN args
parser.add_argument('--time_step', type=float, default=0.5, help='choose the time steps')
parser.add_argument('--time_step',
type=float,
default=0.5,
help='choose the time steps')
# DeepRemaster args
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory')
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster with colorization')
parser.add_argument('--mindim', type=int, default=360, help='Length of minimum image edges')
#process order support model name:[DAIN, DeepRemaster, DeOldify, EDVR]
parser.add_argument('--proccess_order', type=str, default='none', nargs='+', help='Process order')
parser.add_argument('--reference_dir',
type=str,
default=None,
help='Path to the reference image directory')
parser.add_argument('--colorization',
action='store_true',
default=False,
help='Remaster with colorization')
parser.add_argument('--mindim',
type=int,
default=360,
help='Length of minimum image edges')
# DeOldify args
parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
parser.add_argument('--proccess_order',
type=str,
default='none',
nargs='+',
help='Process order')
if __name__ == "__main__":
args = parser.parse_args()
orders = args.proccess_order
temp_video_path = None
for order in orders:
print('Model {} proccess start..'.format(order))
if temp_video_path is None:
temp_video_path = args.input
if order == 'DAIN':
predictor = VideoFrameInterp(args.time_step, args.DAIN_weight,
temp_video_path, output_path=args.output)
predictor = VideoFrameInterp(args.time_step,
args.DAIN_weight,
temp_video_path,
output_path=args.output)
frames_path, temp_video_path = predictor.run()
elif order == 'DeepRemaster':
paddle.disable_static()
predictor = DeepReasterPredictor(temp_video_path, args.output, weight_path=args.DeepRemaster_weight,
colorization=args.colorization, reference_dir=args.reference_dir, mindim=args.mindim)
predictor = DeepReasterPredictor(
temp_video_path,
args.output,
weight_path=args.DeepRemaster_weight,
colorization=args.colorization,
reference_dir=args.reference_dir,
mindim=args.mindim)
frames_path, temp_video_path = predictor.run()
paddle.enable_static()
elif order == 'DeOldify':
elif order == 'DeOldify':
paddle.disable_static()
predictor = DeOldifyPredictor(temp_video_path, args.output, weight_path=args.DeOldify_weight)
predictor = DeOldifyPredictor(temp_video_path,
args.output,
weight_path=args.DeOldify_weight)
frames_path, temp_video_path = predictor.run()
paddle.enable_static()
elif order == 'RealSR':
paddle.disable_static()
predictor = RealSRPredictor(temp_video_path,
args.output,
weight_path=args.RealSR_weight)
frames_path, temp_video_path = predictor.run()
paddle.enable_static()
elif order == 'EDVR':
predictor = EDVRPredictor(temp_video_path, args.output, weight_path=args.EDVR_weight)
predictor = EDVRPredictor(temp_video_path,
args.output,
weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run()
print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path)
print('Model {} proccess done!'.format(order))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部