未验证 提交 61d4939d 编写于 作者: W wangna11BD 提交者: GitHub

add EDVR predictor dynamic (#315)

上级 d81d9cc5
...@@ -119,10 +119,8 @@ if __name__ == "__main__": ...@@ -119,10 +119,8 @@ if __name__ == "__main__":
weight_path=args.RealSR_weight) weight_path=args.RealSR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path) frames_path, temp_video_path = predictor.run(temp_video_path)
elif order == 'EDVR': elif order == 'EDVR':
paddle.enable_static()
predictor = EDVRPredictor(args.output, weight_path=args.EDVR_weight) predictor = EDVRPredictor(args.output, weight_path=args.EDVR_weight)
frames_path, temp_video_path = predictor.run(temp_video_path) frames_path, temp_video_path = predictor.run(temp_video_path)
paddle.disable_static()
print('Model {} output frames path:'.format(order), frames_path) print('Model {} output frames path:'.format(order), frames_path)
print('Model {} output video path:'.format(order), temp_video_path) print('Model {} output video path:'.format(order), temp_video_path)
......
...@@ -19,12 +19,15 @@ import glob ...@@ -19,12 +19,15 @@ import glob
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
import paddle
from paddle.io import Dataset, DataLoader
from ppgan.utils.download import get_path_from_url from ppgan.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators import EDVRNet
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar' EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_SRx4.pdparams'
def get_img(pred): def get_img(pred):
...@@ -110,7 +113,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'): ...@@ -110,7 +113,7 @@ def get_test_neighbor_frames(crt_i, N, max_n, padding='new_info'):
return return_l return return_l
class EDVRDataset: class EDVRDataset(Dataset):
def __init__(self, frame_paths): def __init__(self, frame_paths):
self.frames = frame_paths self.frames = frame_paths
...@@ -133,16 +136,15 @@ class EDVRDataset: ...@@ -133,16 +136,15 @@ class EDVRDataset:
class EDVRPredictor(BasePredictor): class EDVRPredictor(BasePredictor):
def __init__(self, output='output', weight_path=None): def __init__(self, output='output', weight_path=None, bs=1):
self.input = input self.input = input
self.output = os.path.join(output, 'EDVR') self.output = os.path.join(output, 'EDVR')
self.bs = bs
self.model = EDVRNet(nf=128, back_RBs=40)
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(EDVR_WEIGHT_URL) weight_path = get_path_from_url(EDVR_WEIGHT_URL)
self.model.set_dict(paddle.load(weight_path)['generator'])
self.weight_path = weight_path self.model.eval()
self.build_inference_model()
def run(self, video_path): def run(self, video_path):
vid = video_path vid = video_path
...@@ -163,23 +165,23 @@ class EDVRPredictor(BasePredictor): ...@@ -163,23 +165,23 @@ class EDVRPredictor(BasePredictor):
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
dataset = EDVRDataset(frames) test_dataset = EDVRDataset(frames)
dataset = DataLoader(test_dataset, batch_size=self.bs, num_workers=2)
periods = [] periods = []
cur_time = time.time() cur_time = time.time()
for infer_iter, data in enumerate(tqdm(dataset)): for infer_iter, data in enumerate(tqdm(dataset)):
data_feed_in = [data[0]] data_feed_in = paddle.to_tensor(data[0])
with paddle.no_grad():
outs = self.base_forward(np.array(data_feed_in)) outs = self.model(data_feed_in).numpy()
infer_result_list = [outs[i, :, :, :] for i in range(self.bs)]
infer_result_list = [item for item in outs]
frame_path = data[1] frame_path = data[1]
for i in range(self.bs):
img_i = get_img(infer_result_list[0]) img_i = get_img(infer_result_list[i])
save_img( save_img(
img_i, img_i,
os.path.join(pred_frame_path, os.path.basename(frame_path))) os.path.join(pred_frame_path,
os.path.basename(frame_path[i])))
prev_time = cur_time prev_time = cur_time
cur_time = time.time() cur_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册