提交 9ed134f8 编写于 作者: L LielinJiang

Merge branch 'master' of https://github.com/PaddlePaddle/PaddleGAN into sr

...@@ -8,6 +8,7 @@ import time ...@@ -8,6 +8,7 @@ import time
import glob import glob
import numpy as np import numpy as np
from imageio import imread, imsave from imageio import imread, imsave
from tqdm import tqdm
import cv2 import cv2
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -175,8 +176,7 @@ class VideoFrameInterp(object): ...@@ -175,8 +176,7 @@ class VideoFrameInterp(object):
if not os.path.exists(os.path.join(frame_path_combined, vidname)): if not os.path.exists(os.path.join(frame_path_combined, vidname)):
os.makedirs(os.path.join(frame_path_combined, vidname)) os.makedirs(os.path.join(frame_path_combined, vidname))
for i in range(frame_num - 1): for i in tqdm(range(frame_num - 1)):
print(frames[i])
first = frames[i] first = frames[i]
second = frames[i + 1] second = frames[i + 1]
...@@ -208,12 +208,10 @@ class VideoFrameInterp(object): ...@@ -208,12 +208,10 @@ class VideoFrameInterp(object):
assert (X0.shape[1] == X1.shape[1]) assert (X0.shape[1] == X1.shape[1])
assert (X0.shape[2] == X1.shape[2]) assert (X0.shape[2] == X1.shape[2])
print("size before padding ", X0.shape)
X0 = np.pad(X0, ((0,0), (padding_top, padding_bottom), \ X0 = np.pad(X0, ((0,0), (padding_top, padding_bottom), \
(padding_left, padding_right)), mode='edge') (padding_left, padding_right)), mode='edge')
X1 = np.pad(X1, ((0,0), (padding_top, padding_bottom), \ X1 = np.pad(X1, ((0,0), (padding_top, padding_bottom), \
(padding_left, padding_right)), mode='edge') (padding_left, padding_right)), mode='edge')
print("size after padding ", X0.shape)
X0 = np.expand_dims(X0, axis=0) X0 = np.expand_dims(X0, axis=0)
X1 = np.expand_dims(X1, axis=0) X1 = np.expand_dims(X1, axis=0)
...@@ -233,8 +231,6 @@ class VideoFrameInterp(object): ...@@ -233,8 +231,6 @@ class VideoFrameInterp(object):
proc_timer.update(time.time() - proc_end) proc_timer.update(time.time() - proc_end)
tot_timer.update(time.time() - end) tot_timer.update(time.time() - end)
end = time.time() end = time.time()
print("*********** current image process time \t " +
str(time.time() - proc_end) + "s *********")
y_ = [ y_ = [
np.transpose( np.transpose(
......
...@@ -17,8 +17,7 @@ import numpy as np ...@@ -17,8 +17,7 @@ import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Conv2DTranspose from paddle.fluid.dygraph import Conv2D, Conv2DTranspose
from paddle.fluid.contrib import correlation
from .correlation_op.correlation import correlation
__all__ = ['pwc_dc_net'] __all__ = ['pwc_dc_net']
......
...@@ -22,7 +22,7 @@ class AverageMeter(object): ...@@ -22,7 +22,7 @@ class AverageMeter(object):
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0] vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, vid_name) out_full_path = os.path.join(outpath, vid_name)
...@@ -55,30 +55,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -55,30 +55,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd) cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0: if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name)) pass
else: else:
print('Video: {} error'.format(vid_name)) print('ffmpeg process video: {} error'.format(vid_name))
print('')
sys.stdout.flush() sys.stdout.flush()
return out_full_path return out_full_path
def frames_to_video_ffmpeg(framepath, videopath, r): def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [ cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
] ]
cmd = ''.join(cmd) cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0: if os.system(cmd) == 0:
print('Video: {} done'.format(videopath)) pass
else: else:
print('Video: {} error'.format(videopath)) print('ffmpeg process video: {} error'.format(videopath))
print('')
sys.stdout.flush() sys.stdout.flush()
...@@ -99,7 +98,8 @@ def combine_frames(input, interpolated, combined, num_frames): ...@@ -99,7 +98,8 @@ def combine_frames(input, interpolated, combined, num_frames):
for k in range(num_frames): for k in range(num_frames):
src = frames2[i * num_frames + k] src = frames2[i * num_frames + k]
dst = os.path.join( dst = os.path.join(
combined, '{:08d}.png'.format(i * (num_frames + 1) + k + 1)) combined,
'{:08d}.png'.format(i * (num_frames + 1) + k + 1))
shutil.copy2(src, dst) shutil.copy2(src, dst)
except Exception as e: except Exception as e:
print(e) print(e)
......
...@@ -3,14 +3,16 @@ import numpy as np ...@@ -3,14 +3,16 @@ import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
def is_listy(x): def is_listy(x):
return isinstance(x, (tuple,list)) return isinstance(x, (tuple, list))
class Hook(): class Hook():
"Create a hook on `m` with `hook_func`." "Create a hook on `m` with `hook_func`."
def __init__(self, m, hook_func, is_forward=True, detach=True): def __init__(self, m, hook_func, is_forward=True, detach=True):
self.hook_func,self.detach,self.stored = hook_func,detach,None self.hook_func, self.detach, self.stored = hook_func, detach, None
f = m.register_forward_post_hook if is_forward else m.register_backward_hook f = m.register_forward_post_hook if is_forward else m.register_backward_hook
self.hook = f(self.hook_fn) self.hook = f(self.hook_fn)
self.removed = False self.removed = False
...@@ -18,64 +20,90 @@ class Hook(): ...@@ -18,64 +20,90 @@ class Hook():
def hook_fn(self, module, input, output): def hook_fn(self, module, input, output):
"Applies `hook_func` to `module`, `input`, `output`." "Applies `hook_func` to `module`, `input`, `output`."
if self.detach: if self.detach:
input = (o.detach() for o in input ) if is_listy(input ) else input.detach() input = (o.detach()
output = (o.detach() for o in output) if is_listy(output) else output.detach() for o in input) if is_listy(input) else input.detach()
output = (o.detach()
for o in output) if is_listy(output) else output.detach()
self.stored = self.hook_func(module, input, output) self.stored = self.hook_func(module, input, output)
def remove(self): def remove(self):
"Remove the hook from the model." "Remove the hook from the model."
if not self.removed: if not self.removed:
self.hook.remove() self.hook.remove()
self.removed=True self.removed = True
def __enter__(self, *args):
return self
def __exit__(self, *args):
self.remove()
def __enter__(self, *args): return self
def __exit__(self, *args): self.remove()
class Hooks(): class Hooks():
"Create several hooks on the modules in `ms` with `hook_func`." "Create several hooks on the modules in `ms` with `hook_func`."
def __init__(self, ms, hook_func, is_forward=True, detach=True): def __init__(self, ms, hook_func, is_forward=True, detach=True):
self.hooks = [] self.hooks = []
try: try:
for m in ms: for m in ms:
self.hooks.append(Hook(m, hook_func, is_forward, detach)) self.hooks.append(Hook(m, hook_func, is_forward, detach))
except Exception as e: except Exception as e:
print(e) pass
def __getitem__(self, i: int) -> Hook:
return self.hooks[i]
def __len__(self) -> int:
return len(self.hooks)
def __iter__(self):
return iter(self.hooks)
def __getitem__(self,i:int)->Hook: return self.hooks[i]
def __len__(self)->int: return len(self.hooks)
def __iter__(self): return iter(self.hooks)
@property @property
def stored(self): return [o.stored for o in self] def stored(self):
return [o.stored for o in self]
def remove(self): def remove(self):
"Remove the hooks from the model." "Remove the hooks from the model."
for h in self.hooks: h.remove() for h in self.hooks:
h.remove()
def __enter__(self, *args): return self def __enter__(self, *args):
def __exit__ (self, *args): self.remove() return self
def _hook_inner(m,i,o): return o if isinstance(o, paddle.framework.Variable) else o if is_listy(o) else list(o) def __exit__(self, *args):
self.remove()
def hook_output (module, detach=True, grad=False):
def _hook_inner(m, i, o):
return o if isinstance(
o, paddle.framework.Variable) else o if is_listy(o) else list(o)
def hook_output(module, detach=True, grad=False):
"Return a `Hook` that stores activations of `module` in `self.stored`" "Return a `Hook` that stores activations of `module` in `self.stored`"
return Hook(module, _hook_inner, detach=detach, is_forward=not grad) return Hook(module, _hook_inner, detach=detach, is_forward=not grad)
def hook_outputs(modules, detach=True, grad=False): def hook_outputs(modules, detach=True, grad=False):
"Return `Hooks` that store activations of all `modules` in `self.stored`" "Return `Hooks` that store activations of all `modules` in `self.stored`"
return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad) return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
def model_sizes(m, size=(64,64)):
def model_sizes(m, size=(64, 64)):
"Pass a dummy input through the model `m` to get the various sizes of activations." "Pass a dummy input through the model `m` to get the various sizes of activations."
with hook_outputs(m) as hooks: with hook_outputs(m) as hooks:
x = dummy_eval(m, size) x = dummy_eval(m, size)
return [o.stored.shape for o in hooks] return [o.stored.shape for o in hooks]
def dummy_eval(m, size=(64,64)):
def dummy_eval(m, size=(64, 64)):
"Pass a `dummy_batch` in evaluation mode in `m` with `size`." "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
m.eval() m.eval()
return m(dummy_batch(size)) return m(dummy_batch(size))
def dummy_batch(size=(64,64), ch_in=3):
def dummy_batch(size=(64, 64), ch_in=3):
"Create a dummy batch to go through `m` with `size`." "Create a dummy batch to go through `m` with `size`."
arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1 arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1
return paddle.to_tensor(arr) return paddle.to_tensor(arr)
此差异已折叠。
...@@ -20,35 +20,44 @@ from paddle.utils.download import get_path_from_url ...@@ -20,35 +20,44 @@ from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='DeOldify') parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video') parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
parser.add_argument('--weight_path', parser.add_argument('--weight_path',
type=str, type=str,
default='none', default=None,
help='Path to the reference image directory') help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r): def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [ cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
] ]
cmd = ''.join(cmd) cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0: if os.system(cmd) == 0:
print('Video: {} done'.format(videopath)) pass
else: else:
print('Video: {} error'.format(videopath)) print('ffmpeg process video: {} error'.format(videopath))
print('')
sys.stdout.flush() sys.stdout.flush()
class DeOldifyPredictor(): class DeOldifyPredictor():
def __init__(self, input, output, batch_size=1, weight_path=None): def __init__(self,
input,
output,
batch_size=1,
weight_path=None,
render_factor=32):
self.input = input self.input = input
self.output = os.path.join(output, 'DeOldify') self.output = os.path.join(output, 'DeOldify')
self.render_factor = render_factor
self.model = build_model() self.model = build_model()
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(DeOldify_weight_url, cur_path) weight_path = get_path_from_url(DeOldify_weight_url, cur_path)
...@@ -93,7 +102,7 @@ class DeOldifyPredictor(): ...@@ -93,7 +102,7 @@ class DeOldifyPredictor():
def run_single(self, img_path): def run_single(self, img_path):
ori_img = Image.open(img_path).convert('LA').convert('RGB') ori_img = Image.open(img_path).convert('LA').convert('RGB')
img = self.norm(ori_img) img = self.norm(ori_img, self.render_factor)
x = paddle.to_tensor(img[np.newaxis, ...]) x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) out = self.model(x)
...@@ -139,7 +148,7 @@ class DeOldifyPredictor(): ...@@ -139,7 +148,7 @@ class DeOldifyPredictor():
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0] vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, 'frames_input') out_full_path = os.path.join(outpath, 'frames_input')
...@@ -158,23 +167,24 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -158,23 +167,24 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd) cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0: if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name)) pass
else: else:
print('Video: {} error'.format(vid_name)) print('ffmpeg process video: {} error'.format(vid_name))
print('')
sys.stdout.flush() sys.stdout.flush()
return out_full_path return out_full_path
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_imperative() paddle.disable_static()
args = parser.parse_args() args = parser.parse_args()
predictor = DeOldifyPredictor(args.input, predictor = DeOldifyPredictor(args.input,
args.output, args.output,
weight_path=args.weight_path) weight_path=args.weight_path,
render_factor=args.render_factor)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path) print('output video path:', temp_video_path)
...@@ -2,19 +2,20 @@ import cv2 ...@@ -2,19 +2,20 @@ import cv2
import numpy as np import numpy as np
def read_img(path, size=None, is_gt=False): def read_img(path, size=None, is_gt=False):
"""read image by cv2 """read image by cv2
return: Numpy float32, HWC, BGR, [0,1]""" return: Numpy float32, HWC, BGR, [0,1]"""
# print('debug:', path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255. img = img.astype(np.float32) / 255.
if img.ndim == 2: if img.ndim == 2:
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
if img.shape[2] > 3: if img.shape[2] > 3:
img = img[:, :, :3] img = img[:, :, :3]
return img return img
def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'): def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
"""Generate an index list for reading N frames from a sequence of images """Generate an index list for reading N frames from a sequence of images
...@@ -62,7 +63,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'): ...@@ -62,7 +63,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
else: else:
add_idx = i add_idx = i
return_l.append(add_idx) return_l.append(add_idx)
# name_b = '{:08d}'.format(crt_i) # name_b = '{:08d}'.format(crt_i)
return return_l return return_l
...@@ -70,7 +71,6 @@ class EDVRDataset: ...@@ -70,7 +71,6 @@ class EDVRDataset:
def __init__(self, frame_paths): def __init__(self, frame_paths):
self.frames = frame_paths self.frames = frame_paths
def __getitem__(self, index): def __getitem__(self, index):
indexs = get_test_neighbor_frames(index, 5, len(self.frames)) indexs = get_test_neighbor_frames(index, 5, len(self.frames))
frame_list = [] frame_list = []
...@@ -79,7 +79,6 @@ class EDVRDataset: ...@@ -79,7 +79,6 @@ class EDVRDataset:
frame_list.append(img) frame_list.append(img)
img_LQs = np.stack(frame_list, axis=0) img_LQs = np.stack(frame_list, axis=0)
print('img:', img_LQs.shape)
# BGR to RGB, HWC to CHW, numpy to tensor # BGR to RGB, HWC to CHW, numpy to tensor
img_LQs = img_LQs[:, :, :, [2, 1, 0]] img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32') img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
...@@ -87,4 +86,4 @@ class EDVRDataset: ...@@ -87,4 +86,4 @@ class EDVRDataset:
return img_LQs, self.frames[index] return img_LQs, self.frames[index]
def __len__(self): def __len__(self):
return len(self.frames) return len(self.frames)
\ No newline at end of file
...@@ -27,6 +27,7 @@ import numpy as np ...@@ -27,6 +27,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import cv2 import cv2
from tqdm import tqdm
from data import EDVRDataset from data import EDVRDataset
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
...@@ -52,7 +53,6 @@ def parse_args(): ...@@ -52,7 +53,6 @@ def parse_args():
def get_img(pred): def get_img(pred):
print('pred shape', pred.shape)
pred = pred.squeeze() pred = pred.squeeze()
pred = np.clip(pred, a_min=0., a_max=1.0) pred = np.clip(pred, a_min=0., a_max=1.0)
pred = pred * 255 pred = pred * 255
...@@ -72,7 +72,7 @@ def save_img(img, framename): ...@@ -72,7 +72,7 @@ def save_img(img, framename):
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = vid_path.split('/')[-1].split('.')[0] vid_name = vid_path.split('/')[-1].split('.')[0]
out_full_path = os.path.join(outpath, 'frames_input') out_full_path = os.path.join(outpath, 'frames_input')
...@@ -91,30 +91,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None): ...@@ -91,30 +91,29 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat] cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd) cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0: if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name)) pass
else: else:
print('Video: {} error'.format(vid_name)) print('ffmpeg process video: {} error'.format(vid_name))
print('')
sys.stdout.flush() sys.stdout.flush()
return out_full_path return out_full_path
def frames_to_video_ffmpeg(framepath, videopath, r): def frames_to_video_ffmpeg(framepath, videopath, r):
ffmpeg = ['ffmpeg ', ' -loglevel ', ' error '] ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
cmd = ffmpeg + [ cmd = ffmpeg + [
' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ', ' -r ', r, ' -f ', ' image2 ', ' -i ', framepath, ' -vcodec ',
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath ' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
] ]
cmd = ''.join(cmd) cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0: if os.system(cmd) == 0:
print('Video: {} done'.format(videopath)) pass
else: else:
print('Video: {} error'.format(videopath)) print('ffmpeg process video: {} error'.format(videopath))
print('')
sys.stdout.flush() sys.stdout.flush()
...@@ -164,7 +163,7 @@ class EDVRPredictor: ...@@ -164,7 +163,7 @@ class EDVRPredictor:
periods = [] periods = []
cur_time = time.time() cur_time = time.time()
for infer_iter, data in enumerate(dataset): for infer_iter, data in enumerate(tqdm(dataset)):
data_feed_in = [data[0]] data_feed_in = [data[0]]
infer_outs = self.exe.run( infer_outs = self.exe.run(
...@@ -185,7 +184,7 @@ class EDVRPredictor: ...@@ -185,7 +184,7 @@ class EDVRPredictor:
period = cur_time - prev_time period = cur_time - prev_time
periods.append(period) periods.append(period)
print('Processed {} samples'.format(infer_iter + 1)) # print('Processed {} samples'.format(infer_iter + 1))
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, vid_out_path = os.path.join(self.output,
'{}_edvr_out.mp4'.format(base_name)) '{}_edvr_out.mp4'.format(base_name))
......
import os
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(cur_path)
import cv2
import glob
import argparse
import numpy as np
import paddle
import pickle
from PIL import Image
from tqdm import tqdm
from sr_model import RRDBNet
from paddle.utils.download import get_path_from_url
# Command-line interface for running RealSR super-resolution on a single video.
parser = argparse.ArgumentParser(description='RealSR')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
# The help text used to read "Path to the reference image directory", which
# was a copy-paste leftover; this option is the model weight file.
parser.add_argument('--weight_path',
                    type=str,
                    default=None,
                    help='Path to model weight')

# Location of the pretrained RealSR weights (DF2K_JPEG checkpoint).
RealSR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
def frames_to_video_ffmpeg(framepath, videopath, r):
    """Encode an image sequence into an H.264 video by invoking ffmpeg.

    Args:
        framepath: ffmpeg ``image2`` input pattern (e.g. ``dir/%08d.png``).
        videopath: Output video file; an existing file is overwritten (``-y``).
        r: Value for ffmpeg's ``-r`` (frame rate) option; converted with
            ``str`` so both strings and numbers are accepted.

    Failures (non-zero exit status, or ffmpeg not being launchable) are
    reported on stdout; no exception is raised.
    """
    # Imported locally so the function does not depend on a file-level import.
    import subprocess

    # An argument list executed without a shell: paths containing spaces or
    # shell metacharacters are passed to ffmpeg verbatim instead of being
    # re-parsed (and possibly executed) by /bin/sh.
    cmd = [
        'ffmpeg', '-y', '-loglevel', 'error',
        '-r', str(r), '-f', 'image2', '-i', framepath,
        '-vcodec', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '16',
        videopath,
    ]

    try:
        succeeded = subprocess.call(cmd) == 0
    except OSError:
        # ffmpeg missing or not executable: treat like a failed run, which is
        # what the previous shell-based invocation reported in that case.
        succeeded = False

    if not succeeded:
        print('ffmpeg process video: {} error'.format(videopath))

    sys.stdout.flush()
class RealSRPredictor():
    """Run the RealSR (RRDBNet) model frame by frame over a video."""

    def __init__(self, input, output, batch_size=1, weight_path=None):
        """Build the network and load its weights.

        Args:
            input: Path of the video to process.
            output: Base output directory; results go under ``<output>/RealSR``.
            batch_size: Accepted for interface compatibility; not used here.
            weight_path: Weight file to load. When ``None`` the weights are
                fetched from ``RealSR_weight_url``.
        """
        self.input = input
        self.output = os.path.join(output, 'RealSR')
        self.model = RRDBNet(3, 3, 64, 23)
        if weight_path is None:
            weight_path = get_path_from_url(RealSR_weight_url, cur_path)

        state_dict, _ = paddle.load(weight_path)
        self.model.load_dict(state_dict)
        self.model.eval()

    def norm(self, img):
        """Convert an HWC image with values in [0, 255] to CHW float32 in [0, 1]."""
        img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0
        return img.astype('float32')

    def denorm(self, img):
        """Convert a CHW array in [0, 1] back to an HWC uint8 image.

        Values outside [0, 1] are clipped; fractions are truncated by the
        uint8 cast.
        """
        img = img.transpose((1, 2, 0))
        return (img * 255).clip(0, 255).astype('uint8')

    def run_single(self, img_path):
        """Run the model on one image file and return the result as a PIL image."""
        ori_img = Image.open(img_path).convert('RGB')
        img = self.norm(ori_img)
        x = paddle.to_tensor(img[np.newaxis, ...])
        out = self.model(x)

        pred_img = self.denorm(out.numpy()[0])
        pred_img = Image.fromarray(pred_img)
        return pred_img

    def run(self):
        """Process the whole input video.

        Frames are dumped with ffmpeg, each one is passed through the model,
        and the predicted frames are re-encoded into a video.

        Returns:
            A tuple ``(frame_pattern, video_path)``: the ``%08d.png`` pattern
            of the predicted frames and the path of the encoded video.
        """
        vid = self.input
        base_name = os.path.basename(vid).split('.')[0]
        output_path = os.path.join(self.output, base_name)
        pred_frame_path = os.path.join(output_path, 'frames_pred')

        if not os.path.exists(output_path):
            os.makedirs(output_path)

        if not os.path.exists(pred_frame_path):
            os.makedirs(pred_frame_path)

        # The capture is opened only to query the frame rate; release it right
        # away so the video file handle is not kept open during inference.
        cap = cv2.VideoCapture(vid)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()

        out_path = dump_frames_ffmpeg(vid, output_path)

        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))

        for frame in tqdm(frames):
            pred_img = self.run_single(frame)

            # Keep the input frame's file name so the sequence order survives.
            frame_name = os.path.basename(frame)
            pred_img.save(os.path.join(pred_frame_path, frame_name))

        frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')

        vid_out_path = os.path.join(output_path,
                                    '{}_realsr_out.mp4'.format(base_name))
        # NOTE(review): int() truncates fractional frame rates (e.g. 29.97
        # becomes 29) -- confirm this is intended.
        frames_to_video_ffmpeg(frame_pattern_combined, vid_out_path,
                               str(int(fps)))

        return frame_pattern_combined, vid_out_path
def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
    """Extract the frames of a video as numbered PNG files using ffmpeg.

    Frames are written to ``<outpath>/frames_input/%08d.png`` with numbering
    starting at 0; the directory is created when missing.

    Args:
        vid_path: Path of the video to read.
        outpath: Directory under which ``frames_input`` is created.
        r: Value for ffmpeg's ``-r`` option.
        ss: Value for ffmpeg's ``-ss`` option.
        t: Value for ffmpeg's ``-t`` option.

    ``r``, ``ss`` and ``t`` are only applied when all three are given;
    otherwise the whole video is dumped with ffmpeg defaults.

    Returns:
        The directory that receives the frames. It is returned even when
        ffmpeg fails; the failure is only reported on stdout.
    """
    # Imported locally so the function does not depend on a file-level import.
    import subprocess

    vid_name = os.path.basename(vid_path).split('.')[0]
    out_full_path = os.path.join(outpath, 'frames_input')

    if not os.path.exists(out_full_path):
        os.makedirs(out_full_path)

    # Output file name pattern for the extracted frames.
    outformat = os.path.join(out_full_path, '%08d.png')

    # Argument list executed without a shell, so paths with spaces or shell
    # metacharacters reach ffmpeg unchanged instead of being interpreted.
    cmd = ['ffmpeg', '-y', '-loglevel', 'error']
    if ss is not None and t is not None and r is not None:
        cmd += [
            '-ss', str(ss), '-t', str(t), '-i', vid_path, '-r', str(r),
            '-qscale:v', '0.1',
        ]
    else:
        cmd += ['-i', vid_path]
    cmd += ['-start_number', '0', outformat]

    try:
        succeeded = subprocess.call(cmd) == 0
    except OSError:
        # ffmpeg missing or not executable: report it like a failed run.
        succeeded = False

    if not succeeded:
        print('ffmpeg process video: {} error'.format(vid_name))

    sys.stdout.flush()
    return out_full_path
if __name__ == '__main__':
    # Switch paddle out of static-graph mode before the predictor is built.
    paddle.disable_static()

    cli_args = parser.parse_args()
    sr_predictor = RealSRPredictor(cli_args.input,
                                   cli_args.output,
                                   weight_path=cli_args.weight_path)

    # run() returns (frame pattern, video path); only the video is reported.
    run_result = sr_predictor.run()
    frames_path, temp_video_path = run_result

    print('output video path:', temp_video_path)
import functools
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class ResidualDenseBlock_5C(nn.Layer):
    """Densely connected block of five 3x3 convolutions with a scaled residual.

    Each convolution receives the block input concatenated (on axis 1) with
    the activated outputs of all previous convolutions.  The last convolution
    maps back to ``nf`` channels and its output is scaled by 0.2 before being
    added to the input.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc is the growth channel count, i.e. the width of each intermediate
        # feature map; convN therefore sees nf + (N - 1) * gc input channels.
        for idx in range(4):
            setattr(self, 'conv{}'.format(idx + 1),
                    nn.Conv2d(nf + idx * gc, gc, 3, 1, 1, bias_attr=bias))
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        # Features collected so far: the input followed by each activation.
        feats = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(paddle.concat(tuple(feats), 1))))

        # The final convolution is not activated.
        fused = self.conv5(paddle.concat(tuple(feats), 1))
        return fused * 0.2 + x
class RRDB(nn.Layer):
    """Residual in Residual Dense Block.

    Chains three ``ResidualDenseBlock_5C`` layers and adds their output,
    scaled by 0.2, back onto the block input.
    """

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        hidden = x
        # Apply the dense blocks one after another.
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
def make_layer(block, n_layers):
    """Return an ``nn.Sequential`` of ``n_layers`` layers built by ``block()``.

    ``block`` is called with no arguments once per layer.
    """
    return nn.Sequential(*[block() for _ in range(n_layers)])
class RRDBNet(nn.Layer):
    """Super-resolution network built from a trunk of RRDB blocks.

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
        nf: Number of feature channels used throughout the network.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channel count forwarded to every RRDB block.

    The forward pass upsamples twice by a factor of 2 with nearest-neighbour
    interpolation, each followed by a convolution.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        # Zero-argument factory so make_layer can stamp out identical blocks.
        build_rrdb = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias_attr=True)
        self.RRDB_trunk = make_layer(build_rrdb, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)

        # Upsampling stages and reconstruction head.
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias_attr=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        shallow = self.conv_first(x)
        # Skip connection around the whole RRDB trunk.
        feat = shallow + self.trunk_conv(self.RRDB_trunk(shallow))

        # Two upsample -> conv -> LeakyReLU stages.
        for upconv in (self.upconv1, self.upconv2):
            enlarged = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(upconv(enlarged))

        return self.conv_last(self.lrelu(self.HRconv(feat)))
cd DAIN/pwcnet/correlation_op # 模型说明
# 第一次需要执行 # 目前包含DAIN(插帧模型),DeOldify(上色模型),DeepRemaster(去噪与上色模型),EDVR(基于连续帧(视频)超分辨率模型),RealSR(基于图片的超分辨率模型)
# bash make.shap # 参数说明
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python -c 'import paddle; print(paddle.sysconfig.get_lib())'`
export PYTHONPATH=$PYTHONPATH:`pwd`
cd -
# input 输入视频的路径 # input 输入视频的路径
# output 输出视频保存的路径 # output 输出视频保存的路径
# proccess_order 使用模型的顺序 # proccess_order 要使用的模型及顺序
python tools/main.py \ python tools/video-enhance.py \
--input input.mp4 --output output --proccess_order DAIN DeepRemaster DeOldify EDVR --input input.mp4 --output output --proccess_order DeOldify RealSR
...@@ -7,53 +7,109 @@ import paddle ...@@ -7,53 +7,109 @@ import paddle
from DAIN.predict import VideoFrameInterp from DAIN.predict import VideoFrameInterp
from DeepRemaster.predict import DeepReasterPredictor from DeepRemaster.predict import DeepReasterPredictor
from DeOldify.predict import DeOldifyPredictor from DeOldify.predict import DeOldifyPredictor
from RealSR.predict import RealSRPredictor
from EDVR.predict import EDVRPredictor from EDVR.predict import EDVRPredictor
parser = argparse.ArgumentParser(description='Fix video') parser = argparse.ArgumentParser(description='Fix video')
parser.add_argument('--input', type=str, default=None, help='Input video') parser.add_argument('--input', type=str, default=None, help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--DAIN_weight', type=str, default=None, help='Path to model weight') parser.add_argument('--DAIN_weight',
parser.add_argument('--DeepRemaster_weight', type=str, default=None, help='Path to model weight') type=str,
parser.add_argument('--DeOldify_weight', type=str, default=None, help='Path to model weight') default=None,
parser.add_argument('--EDVR_weight', type=str, default=None, help='Path to model weight') help='Path to model weight')
parser.add_argument('--DeepRemaster_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--DeOldify_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--RealSR_weight',
type=str,
default=None,
help='Path to model weight')
parser.add_argument('--EDVR_weight',
type=str,
default=None,
help='Path to model weight')
# DAIN args # DAIN args
parser.add_argument('--time_step', type=float, default=0.5, help='choose the time steps') parser.add_argument('--time_step',
type=float,
default=0.5,
help='choose the time steps')
# DeepRemaster args # DeepRemaster args
parser.add_argument('--reference_dir', type=str, default=None, help='Path to the reference image directory') parser.add_argument('--reference_dir',
parser.add_argument('--colorization', action='store_true', default=False, help='Remaster with colorization') type=str,
parser.add_argument('--mindim', type=int, default=360, help='Length of minimum image edges') default=None,
#process order support model name:[DAIN, DeepRemaster, DeOldify, EDVR] help='Path to the reference image directory')
parser.add_argument('--proccess_order', type=str, default='none', nargs='+', help='Process order') parser.add_argument('--colorization',
action='store_true',
default=False,
help='Remaster with colorization')
parser.add_argument('--mindim',
type=int,
default=360,
help='Length of minimum image edges')
# DeOldify args
parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
parser.add_argument('--proccess_order',
type=str,
default='none',
nargs='+',
help='Process order')
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
orders = args.proccess_order orders = args.proccess_order
temp_video_path = None temp_video_path = None
for order in orders: for order in orders:
print('Model {} proccess start..'.format(order))
if temp_video_path is None: if temp_video_path is None:
temp_video_path = args.input temp_video_path = args.input
if order == 'DAIN': if order == 'DAIN':
predictor = VideoFrameInterp(args.time_step, args.DAIN_weight, predictor = VideoFrameInterp(args.time_step,
temp_video_path, output_path=args.output) args.DAIN_weight,
temp_video_path,
output_path=args.output)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
elif order == 'DeepRemaster': elif order == 'DeepRemaster':
paddle.disable_static() paddle.disable_static()
predictor = DeepReasterPredictor(temp_video_path, args.output, weight_path=args.DeepRemaster_weight, predictor = DeepReasterPredictor(
colorization=args.colorization, reference_dir=args.reference_dir, mindim=args.mindim) temp_video_path,
args.output,
weight_path=args.DeepRemaster_weight,
colorization=args.colorization,
reference_dir=args.reference_dir,
mindim=args.mindim)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
paddle.enable_static() paddle.enable_static()
elif order == 'DeOldify': elif order == 'DeOldify':
paddle.disable_static() paddle.disable_static()
predictor = DeOldifyPredictor(temp_video_path, args.output, weight_path=args.DeOldify_weight) predictor = DeOldifyPredictor(temp_video_path,
args.output,
weight_path=args.DeOldify_weight)
frames_path, temp_video_path = predictor.run()
paddle.enable_static()
elif order == 'RealSR':
paddle.disable_static()
predictor = RealSRPredictor(temp_video_path,
args.output,
weight_path=args.RealSR_weight)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
paddle.enable_static() paddle.enable_static()
elif order == 'EDVR': elif order == 'EDVR':
predictor = EDVRPredictor(temp_video_path, args.output, weight_path=args.EDVR_weight) predictor = EDVRPredictor(temp_video_path,
args.output,
weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run() frames_path, temp_video_path = predictor.run()
print('Model {} output frames path:'.format(order), frames_path) print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path) print('Model {} output video path:'.format(order), temp_video_path)
print('Model {} proccess done!'.format(order))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册