未验证 提交 4cd9a0b1 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #28 from lijianshe02/master

add remove duplicate frames in DAIN
...@@ -90,4 +90,7 @@ parser.add_argument('--use_cuda', ...@@ -90,4 +90,7 @@ parser.add_argument('--use_cuda',
type=bool, type=bool,
help='use cuda or not') help='use cuda or not')
parser.add_argument('--use_cudnn', default=1, type=int, help='use cudnn or not') parser.add_argument('--use_cudnn', default=1, type=int, help='use cudnn or not')
parser.add_argument('--remove_duplicates',
default=True,
type=bool,
help='remove duplicate frames or not')
...@@ -80,7 +80,8 @@ class VideoFrameInterp(object): ...@@ -80,7 +80,8 @@ class VideoFrameInterp(object):
video_path, video_path,
use_gpu=True, use_gpu=True,
key_frame_thread=0., key_frame_thread=0.,
output_path='output'): output_path='output',
remove_duplicates=True):
self.video_path = video_path self.video_path = video_path
self.output_path = os.path.join(output_path, 'DAIN') self.output_path = os.path.join(output_path, 'DAIN')
if model_path is None: if model_path is None:
...@@ -138,6 +139,8 @@ class VideoFrameInterp(object): ...@@ -138,6 +139,8 @@ class VideoFrameInterp(object):
end = time.time() end = time.time()
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
if remove_duplicates:
frames = remove_duplicates(out_path)
img = imread(frames[0]) img = imread(frames[0])
...@@ -199,58 +202,51 @@ class VideoFrameInterp(object): ...@@ -199,58 +202,51 @@ class VideoFrameInterp(object):
X0 = img_first.astype('float32').transpose((2, 0, 1)) / 255 X0 = img_first.astype('float32').transpose((2, 0, 1)) / 255
X1 = img_second.astype('float32').transpose((2, 0, 1)) / 255 X1 = img_second.astype('float32').transpose((2, 0, 1)) / 255
if key_frame: assert (X0.shape[1] == X1.shape[1])
y_ = [ assert (X0.shape[2] == X1.shape[2])
np.transpose(255.0 * X0.clip(0, 1.0), (1, 2, 0))
for i in range(num_frames) X0 = np.pad(X0, ((0,0), (padding_top, padding_bottom), \
] (padding_left, padding_right)), mode='edge')
else: X1 = np.pad(X1, ((0,0), (padding_top, padding_bottom), \
assert (X0.shape[1] == X1.shape[1]) (padding_left, padding_right)), mode='edge')
assert (X0.shape[2] == X1.shape[2])
X0 = np.expand_dims(X0, axis=0)
X0 = np.pad(X0, ((0,0), (padding_top, padding_bottom), \ X1 = np.expand_dims(X1, axis=0)
(padding_left, padding_right)), mode='edge')
X1 = np.pad(X1, ((0,0), (padding_top, padding_bottom), \ X0 = np.expand_dims(X0, axis=0)
(padding_left, padding_right)), mode='edge') X1 = np.expand_dims(X1, axis=0)
X0 = np.expand_dims(X0, axis=0) X = np.concatenate((X0, X1), axis=0)
X1 = np.expand_dims(X1, axis=0)
proc_end = time.time()
X0 = np.expand_dims(X0, axis=0) o = self.exe.run(self.program,
X1 = np.expand_dims(X1, axis=0) fetch_list=self.fetch_targets,
feed={"image": X})
X = np.concatenate((X0, X1), axis=0)
y_ = o[0]
proc_end = time.time()
o = self.exe.run(self.program, proc_timer.update(time.time() - proc_end)
fetch_list=self.fetch_targets, tot_timer.update(time.time() - end)
feed={"image": X}) end = time.time()
y_ = o[0] y_ = [
np.transpose(
proc_timer.update(time.time() - proc_end) 255.0 * item.clip(
tot_timer.update(time.time() - end) 0, 1.0)[0, :, padding_top:padding_top + int_height,
end = time.time() padding_left:padding_left + int_width],
(1, 2, 0)) for item in y_
y_ = [ ]
np.transpose( time_offsets = [
255.0 * item.clip( kk * timestep for kk in range(1, 1 + num_frames, 1)
0, 1.0)[0, :, ]
padding_top:padding_top + int_height,
padding_left:padding_left + int_width], count = 1
(1, 2, 0)) for item in y_ for item, time_offset in zip(y_, time_offsets):
] out_dir = os.path.join(
time_offsets = [ frame_path_interpolated, vidname,
kk * timestep for kk in range(1, 1 + num_frames, 1) "{:0>6d}_{:0>4d}.png".format(i, count))
] count = count + 1
imsave(out_dir, np.round(item).astype(np.uint8))
count = 1
for item, time_offset in zip(y_, time_offsets):
out_dir = os.path.join(
frame_path_interpolated, vidname,
"{:0>6d}_{:0>4d}.png".format(i, count))
count = count + 1
imsave(out_dir, np.round(item).astype(np.uint8))
num_frames = int(1.0 / timestep) - 1 num_frames = int(1.0 / timestep) - 1
...@@ -266,14 +262,16 @@ class VideoFrameInterp(object): ...@@ -266,14 +262,16 @@ class VideoFrameInterp(object):
vidname + '.mp4') vidname + '.mp4')
if os.path.exists(video_pattern_output): if os.path.exists(video_pattern_output):
os.remove(video_pattern_output) os.remove(video_pattern_output)
frames2video(frame_pattern_combined, video_pattern_output, frames2video(frame_pattern_combined, video_pattern_output, r2)
r2)
return frame_pattern_combined, video_pattern_output return frame_pattern_combined, video_pattern_output
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
predictor = VideoFrameInterp(args.time_step, args.saved_model, predictor = VideoFrameInterp(args.time_step,
args.video_path, args.output_path) args.saved_model,
args.video_path,
args.output_path,
remove_duplicates=args.remove_duplicates)
predictor.run() predictor.run()
import os, sys import os, sys
import glob import glob
import shutil import shutil
import cv2
class AverageMeter(object): class AverageMeter(object):
...@@ -44,3 +45,34 @@ def combine_frames(input, interpolated, combined, num_frames): ...@@ -44,3 +45,34 @@ def combine_frames(input, interpolated, combined, num_frames):
except Exception as e: except Exception as e:
print(e) print(e)
print(len(frames2), num_frames, i, k, i * num_frames + k) print(len(frames2), num_frames, i, k, i * num_frames + k)
def remove_duplicates(paths):
def dhash(image, hash_size=8):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (hash_size + 1, hash_size))
diff = resized[:, 1:] > resized[:, :-1]
return sum([2**i for (i, v) in enumerate(diff.flatten()) if v])
hashes = {}
image_paths = sorted(glob.glob(os.path.join(paths, '*.png')))
for image_path in image_paths:
image = cv2.imread(image_path)
h = dhash(image)
p = hashes.get(h, [])
p.append(image_path)
hashes[h] = p
for (h, hashed_paths) in hashes.items():
if len(hashed_paths) > 1:
for p in hashed_paths[1:]:
os.remove(p)
frames = sorted(glob.glob(os.path.join(paths, '*.png')))
for fid, frame in enumerate(frames):
new_name = '{:08d}'.format(fid) + '.png'
new_name = os.path.join(paths, new_name)
os.rename(frame, new_name)
frames = sorted(glob.glob(os.path.join(paths, '*.png')))
return frames
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册