未验证 提交 3bc13ff3 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #22 from LielinJiang/refine-code

Refine print log and add args
......@@ -55,12 +55,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
......@@ -72,13 +72,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
......
......@@ -3,14 +3,16 @@ import numpy as np
import paddle
import paddle.nn as nn
def is_listy(x):
return isinstance(x, (tuple,list))
return isinstance(x, (tuple, list))
class Hook():
"Create a hook on `m` with `hook_func`."
def __init__(self, m, hook_func, is_forward=True, detach=True):
self.hook_func,self.detach,self.stored = hook_func,detach,None
self.hook_func, self.detach, self.stored = hook_func, detach, None
f = m.register_forward_post_hook if is_forward else m.register_backward_hook
self.hook = f(self.hook_fn)
self.removed = False
......@@ -18,64 +20,90 @@ class Hook():
def hook_fn(self, module, input, output):
"Applies `hook_func` to `module`, `input`, `output`."
if self.detach:
input = (o.detach() for o in input ) if is_listy(input ) else input.detach()
output = (o.detach() for o in output) if is_listy(output) else output.detach()
input = (o.detach()
for o in input) if is_listy(input) else input.detach()
output = (o.detach()
for o in output) if is_listy(output) else output.detach()
self.stored = self.hook_func(module, input, output)
def remove(self):
"Remove the hook from the model."
if not self.removed:
self.hook.remove()
self.removed=True
self.removed = True
def __enter__(self, *args):
return self
def __exit__(self, *args):
self.remove()
def __enter__(self, *args): return self
def __exit__(self, *args): self.remove()
class Hooks():
"Create several hooks on the modules in `ms` with `hook_func`."
def __init__(self, ms, hook_func, is_forward=True, detach=True):
self.hooks = []
try:
for m in ms:
self.hooks.append(Hook(m, hook_func, is_forward, detach))
except Exception as e:
print(e)
pass
def __getitem__(self, i: int) -> Hook:
return self.hooks[i]
def __len__(self) -> int:
return len(self.hooks)
def __iter__(self):
return iter(self.hooks)
def __getitem__(self,i:int)->Hook: return self.hooks[i]
def __len__(self)->int: return len(self.hooks)
def __iter__(self): return iter(self.hooks)
@property
def stored(self): return [o.stored for o in self]
def stored(self):
return [o.stored for o in self]
def remove(self):
"Remove the hooks from the model."
for h in self.hooks: h.remove()
for h in self.hooks:
h.remove()
def __enter__(self, *args): return self
def __exit__ (self, *args): self.remove()
def __enter__(self, *args):
return self
def _hook_inner(m,i,o): return o if isinstance(o, paddle.framework.Variable) else o if is_listy(o) else list(o)
def __exit__(self, *args):
self.remove()
def hook_output (module, detach=True, grad=False):
def _hook_inner(m, i, o):
return o if isinstance(
o, paddle.framework.Variable) else o if is_listy(o) else list(o)
def hook_output(module, detach=True, grad=False):
"Return a `Hook` that stores activations of `module` in `self.stored`"
return Hook(module, _hook_inner, detach=detach, is_forward=not grad)
def hook_outputs(modules, detach=True, grad=False):
"Return `Hooks` that store activations of all `modules` in `self.stored`"
return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
def model_sizes(m, size=(64,64)):
def model_sizes(m, size=(64, 64)):
"Pass a dummy input through the model `m` to get the various sizes of activations."
with hook_outputs(m) as hooks:
x = dummy_eval(m, size)
return [o.stored.shape for o in hooks]
def dummy_eval(m, size=(64,64)):
def dummy_eval(m, size=(64, 64)):
"Pass a `dummy_batch` in evaluation mode in `m` with `size`."
m.eval()
return m(dummy_batch(size))
def dummy_batch(size=(64,64), ch_in=3):
def dummy_batch(size=(64, 64), ch_in=3):
"Create a dummy batch to go through `m` with `size`."
arr = np.random.rand(1, ch_in, *size).astype('float32') * 2 - 1
return paddle.to_tensor(arr)
......@@ -20,6 +20,10 @@ from paddle.utils.download import get_path_from_url
parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
parser.add_argument('--weight_path',
type=str,
default=None,
......@@ -35,20 +39,25 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
class DeOldifyPredictor():
def __init__(self, input, output, batch_size=1, weight_path=None):
def __init__(self,
input,
output,
batch_size=1,
weight_path=None,
render_factor=32):
self.input = input
self.output = os.path.join(output, 'DeOldify')
self.render_factor = render_factor
self.model = build_model()
if weight_path is None:
weight_path = get_path_from_url(DeOldify_weight_url, cur_path)
......@@ -93,7 +102,7 @@ class DeOldifyPredictor():
def run_single(self, img_path):
ori_img = Image.open(img_path).convert('LA').convert('RGB')
img = self.norm(ori_img)
img = self.norm(ori_img, self.render_factor)
x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x)
......@@ -158,12 +167,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
......@@ -174,7 +183,8 @@ if __name__ == '__main__':
predictor = DeOldifyPredictor(args.input,
args.output,
weight_path=args.weight_path)
weight_path=args.weight_path,
render_factor=args.render_factor)
frames_path, temp_video_path = predictor.run()
print('output video path:', temp_video_path)
......@@ -91,12 +91,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
......@@ -108,13 +108,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
......
......@@ -34,13 +34,12 @@ def frames_to_video_ffmpeg(framepath, videopath, r):
' libx264 ', ' -pix_fmt ', ' yuv420p ', ' -crf ', ' 16 ', videopath
]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(videopath))
pass
else:
print('Video: {} error'.format(videopath))
print('')
print('ffmpeg process video: {} error'.format(videopath))
sys.stdout.flush()
......@@ -129,12 +128,12 @@ def dump_frames_ffmpeg(vid_path, outpath, r=None, ss=None, t=None):
cmd = ffmpeg + [' -i ', vid_path, ' -start_number ', ' 0 ', outformat]
cmd = ''.join(cmd)
print(cmd)
if os.system(cmd) == 0:
print('Video: {} done'.format(vid_name))
pass
else:
print('Video: {} error'.format(vid_name))
print('')
print('ffmpeg process video: {} error'.format(vid_name))
sys.stdout.flush()
return out_full_path
......
......@@ -51,6 +51,11 @@ parser.add_argument('--mindim',
type=int,
default=360,
help='Length of minimum image edges')
# DeOldify args
parser.add_argument('--render_factor',
type=int,
default=32,
help='model inputsize=render_factor*16')
#process order support model name:[DAIN, DeepRemaster, DeOldify, RealSR, EDVR]
parser.add_argument('--proccess_order',
type=str,
......@@ -65,6 +70,7 @@ if __name__ == "__main__":
temp_video_path = None
for order in orders:
print('Model {} proccess start..'.format(order))
if temp_video_path is None:
temp_video_path = args.input
if order == 'DAIN':
......@@ -106,3 +112,4 @@ if __name__ == "__main__":
print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path)
print('Model {} proccess done!'.format(order))
......@@ -47,7 +47,7 @@ class Trainer:
self.time_count = {}
def distributed_data_parallel(self):
strategy = paddle.prepare_context()
strategy = paddle.distributed.prepare_context()
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部