From 1744b3a6e950371f0e15012105b1bdfdf57c4486 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Mon, 12 Oct 2020 03:51:45 +0000 Subject: [PATCH] add run image --- ppgan/apps/base_predictor.py | 14 ++++++++------ ppgan/apps/deoldify_predictor.py | 32 ++++++++++++++++++++++++------- ppgan/apps/realsr_predictor.py | 33 ++++++++++++++++++++++++-------- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/ppgan/apps/base_predictor.py b/ppgan/apps/base_predictor.py index 8727143..e250391 100644 --- a/ppgan/apps/base_predictor.py +++ b/ppgan/apps/base_predictor.py @@ -13,6 +13,7 @@ #limitations under the License. import os +import cv2 import paddle @@ -60,11 +61,12 @@ class BasePredictor(object): return out - def preprocess(self, inputs): - pass - - def postprocess(self, inputs): - pass + def is_video(self, input): + try: + cv2.VideoCapture(input) + return True + except: + return False def run(self): - pass + raise NotImplementedError diff --git a/ppgan/apps/deoldify_predictor.py b/ppgan/apps/deoldify_predictor.py index e862c28..d6d3371 100644 --- a/ppgan/apps/deoldify_predictor.py +++ b/ppgan/apps/deoldify_predictor.py @@ -77,8 +77,14 @@ class DeOldifyPredictor(BasePredictor): final = Image.fromarray(final) return final - def run_single(self, img_path): - ori_img = Image.open(img_path).convert('LA').convert('RGB') + def run_image(self, img): + if isinstance(img, str): + ori_img = Image.open(img).convert('LA').convert('RGB') + elif isinstance(img, np.ndarray): + ori_img = Image.fromarray(img).convert('LA').convert('RGB') + elif isinstance(img, Image.Image): + ori_img = img + img = self.norm(ori_img, self.render_factor) x = paddle.to_tensor(img[np.newaxis, ...]) out = self.model(x) @@ -89,8 +95,8 @@ class DeOldifyPredictor(BasePredictor): pred_img = self.post_process(pred_img, ori_img) return pred_img - def run(self, video_path): - base_name = os.path.basename(video_path).split('.')[0] + def run_video(self, video): + base_name = os.path.basename(video).split('.')[0] output_path = os.path.join(self.output, base_name) pred_frame_path = os.path.join(output_path, 'frames_pred') @@ -100,15 +106,15 @@ class DeOldifyPredictor(BasePredictor): if not os.path.exists(pred_frame_path): os.makedirs(pred_frame_path) - cap = cv2.VideoCapture(video_path) + cap = cv2.VideoCapture(video) fps = cap.get(cv2.CAP_PROP_FPS) - out_path = video2frames(video_path, output_path) + out_path = video2frames(video, output_path) frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) for frame in tqdm(frames): - pred_img = self.run_single(frame) + pred_img = self.run_image(frame) frame_name = os.path.basename(frame) pred_img.save(os.path.join(pred_frame_path, frame_name)) @@ -120,3 +126,15 @@ class DeOldifyPredictor(BasePredictor): frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) return frame_pattern_combined, vid_out_path + + def run(self, input): + if self.is_video(input): + return self.run_video(input) + else: + pred_img = self.run_image(input) + + if self.output: + base_name = os.path.basename(input) + pred_img.save(os.path.join(self.output, base_name + '.png')) + + return pred_img diff --git a/ppgan/apps/realsr_predictor.py b/ppgan/apps/realsr_predictor.py index 9f16170..4c11f49 100644 --- a/ppgan/apps/realsr_predictor.py +++ b/ppgan/apps/realsr_predictor.py @@ -49,8 +49,14 @@ class RealSRPredictor(BasePredictor): img = img.transpose((1, 2, 0)) return (img * 255).clip(0, 255).astype('uint8') - def run_single(self, img_path): - ori_img = Image.open(img_path).convert('RGB') + def run_image(self, img): + if isinstance(img, str): + ori_img = Image.open(img).convert('RGB') + elif isinstance(img, np.ndarray): + ori_img = Image.fromarray(img).convert('RGB') + elif isinstance(img, Image.Image): + ori_img = img + img = self.norm(ori_img) x = paddle.to_tensor(img[np.newaxis, ...]) out = self.model(x) @@ -59,9 +65,8 @@ class RealSRPredictor(BasePredictor): pred_img = Image.fromarray(pred_img) return pred_img - def run(self, video_path): - vid = video_path - base_name = os.path.basename(vid).split('.')[0] + def run_video(self, video): + base_name = os.path.basename(video).split('.')[0] output_path = os.path.join(self.output, base_name) pred_frame_path = os.path.join(output_path, 'frames_pred') @@ -71,15 +76,15 @@ class RealSRPredictor(BasePredictor): if not os.path.exists(pred_frame_path): os.makedirs(pred_frame_path) - cap = cv2.VideoCapture(vid) + cap = cv2.VideoCapture(video) fps = cap.get(cv2.CAP_PROP_FPS) - out_path = video2frames(vid, output_path) + out_path = video2frames(video, output_path) frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) for frame in tqdm(frames): - pred_img = self.run_single(frame) + pred_img = self.run_image(frame) frame_name = os.path.basename(frame) pred_img.save(os.path.join(pred_frame_path, frame_name)) @@ -91,3 +96,15 @@ class RealSRPredictor(BasePredictor): frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) return frame_pattern_combined, vid_out_path + + def run(self, input): + if self.is_video(input): + return self.run_video(input) + else: + pred_img = self.run_image(input) + + if self.output: + base_name = os.path.basename(input) + pred_img.save(os.path.join(self.output, base_name + '.png')) + + return pred_img -- GitLab