# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import cv2
import time
import glob
import numpy as np
from tqdm import tqdm

from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames

from .base_predictor import BasePredictor

EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'


def get_img(pred):
    """Convert a model output array into an 8-bit image ready for cv2.

    Args:
        pred (numpy.ndarray): array that becomes 3-D after ``squeeze()``;
            it is treated as CHW / RGB, as the transposition and channel
            flip below assume.

    Returns:
        numpy.ndarray: uint8 array, HWC layout, channel order reversed
        (RGB -> BGR).
    """
    pred = pred.squeeze()
    # Clip to [0, 1] before scaling so the uint8 cast cannot wrap around.
    pred = np.clip(pred, a_min=0., a_max=1.0)
    pred = pred * 255
    pred = pred.round()
    pred = pred.astype('uint8')
    pred = np.transpose(pred, (1, 2, 0))  # chw -> hwc
    pred = pred[:, :, ::-1]  # rgb -> bgr
    return pred


def save_img(img, framename):
    """Write ``img`` to ``framename`` with cv2, creating the parent directory.

    Args:
        img (numpy.ndarray): image passed unchanged to ``cv2.imwrite``.
        framename (str): destination file path.
    """
    dirname = os.path.dirname(framename)
    # A bare file name has an empty dirname; os.makedirs('') would raise.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    cv2.imwrite(framename, img)


def read_img(path, size=None, is_gt=False):
    """Read an image with cv2 and scale it by 1/255.

    Args:
        path (str): image file path.
        size: unused; kept for interface compatibility.
        is_gt (bool): unused; kept for interface compatibility.

    Returns:
        numpy.ndarray: float32, HWC, BGR, values in [0, 1] for 8-bit
        input. Grayscale images get a trailing channel axis and any
        channel beyond the third (e.g. alpha) is dropped.

    Raises:
        IOError: if cv2 cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise IOError('Failed to read image: {}'.format(path))
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
    """Generate an index list for reading N frames from a sequence of images

    Args:
        crt_i (int): current center index
        max_n (int): max number of the sequence of images (calculated from 1)
        N (int): reading N frames
        padding (str): padding mode, one of
            replicate | reflection | new_info | circle
            Example: crt_i = 0, N = 5
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            new_info: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        return_l (list [int]): a list of indexes

    Raises:
        ValueError: if an index falls outside the sequence and ``padding``
            is not one of the supported modes.
    """
    max_n = max_n - 1  # convert the 1-based count into the last valid index
    n_pad = N // 2
    return_l = []

    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
        if i < 0:
            if padding == 'replicate':
                add_idx = 0
            elif padding == 'reflection':
                add_idx = -i
            elif padding == 'new_info':
                add_idx = (crt_i + n_pad) + (-i)
            elif padding == 'circle':
                add_idx = N + i
            else:
                raise ValueError('Wrong padding mode')
        elif i > max_n:
            if padding == 'replicate':
                add_idx = max_n
            elif padding == 'reflection':
                add_idx = max_n * 2 - i
            elif padding == 'new_info':
                add_idx = (crt_i - n_pad) - (i - max_n)
            elif padding == 'circle':
                add_idx = i - N
            else:
                raise ValueError('Wrong padding mode')
        else:
            add_idx = i
        return_l.append(add_idx)
    return return_l


class EDVRDataset:
    """Sequence of 5-frame windows centred on each frame of a frame list."""

    def __init__(self, frame_paths):
        """Store the list of frame image paths."""
        self.frames = frame_paths

    def __getitem__(self, index):
        """Return the 5-frame window around ``index`` and that frame's path.

        Args:
            index (int): centre frame index; negative values count from
                the end.

        Returns:
            tuple: ``(img_LQs, path)`` where ``img_LQs`` is a float32 array
            of the five stacked frames with channels reversed (BGR -> RGB)
            and axes ordered (frame, channel, height, width), and ``path``
            is ``self.frames[index]``.

        Raises:
            IndexError: if ``index`` is outside the frame list. Plain
                ``for`` iteration over this object relies on this to stop.
        """
        num_frames = len(self.frames)
        if index < 0:
            index += num_frames
        # Check bounds first so no image is read for an invalid index.
        if not 0 <= index < num_frames:
            raise IndexError('frame index out of range')

        indexs = get_test_neighbor_frames(index, 5, num_frames)
        # For sequences shorter than the window the padding rule can yield
        # indices outside the list; clamp them to the nearest valid frame.
        last = num_frames - 1
        frame_list = [
            read_img(self.frames[min(max(i, 0), last)]) for i in indexs
        ]
        img_LQs = np.stack(frame_list, axis=0)
        # BGR to RGB, HWC to CHW
        img_LQs = img_LQs[:, :, :, [2, 1, 0]]
        img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')

        return img_LQs, self.frames[index]

    def __len__(self):
        """Return the number of frames."""
        return len(self.frames)


class EDVRPredictor(BasePredictor):
    """Run the EDVR inference model over every frame of a video."""

    def __init__(self, output='output', weight_path=None):
        """Set up the output directory and load the inference model.

        Args:
            output (str): root output directory; results are written to
                its ``EDVR`` subdirectory.
            weight_path (str, optional): model weight location. When None,
                it is obtained via ``get_path_from_url(EDVR_WEIGHT_URL)``.
        """
        # NOTE(review): there is no ``input`` parameter, so this stores the
        # builtin ``input`` function. Kept as-is for compatibility; confirm
        # whether BasePredictor needs this attribute at all.
        self.input = input
        self.output = os.path.join(output, 'EDVR')

        if weight_path is None:
            weight_path = get_path_from_url(EDVR_WEIGHT_URL)

        self.weight_path = weight_path

        self.build_inference_model()

    def run(self, video_path):
        """Process a video frame by frame and assemble the output video.

        Args:
            video_path (str): path of the input video.

        Returns:
            tuple: ``(frame_pattern_combined, vid_out_path)``, the
            ``%08d.png`` pattern of the predicted frames and the path
            passed to ``frames2video`` as the output video.
        """
        vid = video_path
        base_name = os.path.basename(vid).split('.')[0]
        output_path = os.path.join(self.output, base_name)
        pred_frame_path = os.path.join(output_path, 'frames_pred')

        # pred_frame_path lives inside output_path, so this creates both.
        os.makedirs(pred_frame_path, exist_ok=True)

        # The capture is opened only to query the frame rate; release it
        # right away so the underlying handle is not leaked.
        cap = cv2.VideoCapture(vid)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()

        out_path = video2frames(vid, output_path)

        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))

        dataset = EDVRDataset(frames)

        for img_lqs, frame_path in tqdm(dataset):
            # Add a leading axis of size 1 around the frame window.
            outs = self.base_forward(np.array([img_lqs]))

            img_i = get_img(list(outs)[0])
            save_img(
                img_i,
                os.path.join(pred_frame_path, os.path.basename(frame_path)))

        frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
        vid_out_path = os.path.join(self.output,
                                    '{}_edvr_out.mp4'.format(base_name))
        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))

        return frame_pattern_combined, vid_out_path