提交 1744b3a6 编写于 作者: L LielinJiang

add run image

上级 30159cf1
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#limitations under the License. #limitations under the License.
import os import os
import cv2
import paddle import paddle
...@@ -60,11 +61,12 @@ class BasePredictor(object): ...@@ -60,11 +61,12 @@ class BasePredictor(object):
return out return out
def preprocess(self, inputs): def is_video(self, input):
pass try:
cv2.VideoCapture(input)
def postprocess(self, inputs): return True
pass except:
return False
def run(self): def run(self):
pass raise NotImplementedError
...@@ -77,8 +77,14 @@ class DeOldifyPredictor(BasePredictor): ...@@ -77,8 +77,14 @@ class DeOldifyPredictor(BasePredictor):
final = Image.fromarray(final) final = Image.fromarray(final)
return final return final
def run_single(self, img_path): def run_image(self, img):
ori_img = Image.open(img_path).convert('LA').convert('RGB') if isinstance(img, str):
ori_img = Image.open(img).convert('LA').convert('RGB')
elif isinstance(img, np.ndarray):
ori_img = Image.fromarray(img).convert('LA').convert('RGB')
elif isinstance(img, Image.Image):
ori_img = img
img = self.norm(ori_img, self.render_factor) img = self.norm(ori_img, self.render_factor)
x = paddle.to_tensor(img[np.newaxis, ...]) x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) out = self.model(x)
...@@ -89,8 +95,8 @@ class DeOldifyPredictor(BasePredictor): ...@@ -89,8 +95,8 @@ class DeOldifyPredictor(BasePredictor):
pred_img = self.post_process(pred_img, ori_img) pred_img = self.post_process(pred_img, ori_img)
return pred_img return pred_img
def run(self, video_path): def run_video(self, video):
base_name = os.path.basename(video_path).split('.')[0] base_name = os.path.basename(video).split('.')[0]
output_path = os.path.join(self.output, base_name) output_path = os.path.join(self.output, base_name)
pred_frame_path = os.path.join(output_path, 'frames_pred') pred_frame_path = os.path.join(output_path, 'frames_pred')
...@@ -100,15 +106,15 @@ class DeOldifyPredictor(BasePredictor): ...@@ -100,15 +106,15 @@ class DeOldifyPredictor(BasePredictor):
if not os.path.exists(pred_frame_path): if not os.path.exists(pred_frame_path):
os.makedirs(pred_frame_path) os.makedirs(pred_frame_path)
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
out_path = video2frames(video_path, output_path) out_path = video2frames(video, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
for frame in tqdm(frames): for frame in tqdm(frames):
pred_img = self.run_single(frame) pred_img = self.run_image(frame)
frame_name = os.path.basename(frame) frame_name = os.path.basename(frame)
pred_img.save(os.path.join(pred_frame_path, frame_name)) pred_img.save(os.path.join(pred_frame_path, frame_name))
...@@ -120,3 +126,15 @@ class DeOldifyPredictor(BasePredictor): ...@@ -120,3 +126,15 @@ class DeOldifyPredictor(BasePredictor):
frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def run(self, input):
if self.is_video(input):
return self.run_video(input)
else:
pred_img = self.run_image(input)
if self.output:
base_name = os.path.basename(input)
pred_img.save(os.path.join(self.output, base_name + '.png'))
return pred_img
...@@ -49,8 +49,14 @@ class RealSRPredictor(BasePredictor): ...@@ -49,8 +49,14 @@ class RealSRPredictor(BasePredictor):
img = img.transpose((1, 2, 0)) img = img.transpose((1, 2, 0))
return (img * 255).clip(0, 255).astype('uint8') return (img * 255).clip(0, 255).astype('uint8')
def run_single(self, img_path): def run_image(self, img):
ori_img = Image.open(img_path).convert('RGB') if isinstance(img, str):
ori_img = Image.open(img).convert('RGB')
elif isinstance(img, np.ndarray):
ori_img = Image.fromarray(img).convert('RGB')
elif isinstance(img, Image.Image):
ori_img = img
img = self.norm(ori_img) img = self.norm(ori_img)
x = paddle.to_tensor(img[np.newaxis, ...]) x = paddle.to_tensor(img[np.newaxis, ...])
out = self.model(x) out = self.model(x)
...@@ -59,9 +65,8 @@ class RealSRPredictor(BasePredictor): ...@@ -59,9 +65,8 @@ class RealSRPredictor(BasePredictor):
pred_img = Image.fromarray(pred_img) pred_img = Image.fromarray(pred_img)
return pred_img return pred_img
def run(self, video_path): def run_video(self, video):
vid = video_path base_name = os.path.basename(video).split('.')[0]
base_name = os.path.basename(vid).split('.')[0]
output_path = os.path.join(self.output, base_name) output_path = os.path.join(self.output, base_name)
pred_frame_path = os.path.join(output_path, 'frames_pred') pred_frame_path = os.path.join(output_path, 'frames_pred')
...@@ -71,15 +76,15 @@ class RealSRPredictor(BasePredictor): ...@@ -71,15 +76,15 @@ class RealSRPredictor(BasePredictor):
if not os.path.exists(pred_frame_path): if not os.path.exists(pred_frame_path):
os.makedirs(pred_frame_path) os.makedirs(pred_frame_path)
cap = cv2.VideoCapture(vid) cap = cv2.VideoCapture(video)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
out_path = video2frames(vid, output_path) out_path = video2frames(video, output_path)
frames = sorted(glob.glob(os.path.join(out_path, '*.png'))) frames = sorted(glob.glob(os.path.join(out_path, '*.png')))
for frame in tqdm(frames): for frame in tqdm(frames):
pred_img = self.run_single(frame) pred_img = self.run_image(frame)
frame_name = os.path.basename(frame) frame_name = os.path.basename(frame)
pred_img.save(os.path.join(pred_frame_path, frame_name)) pred_img.save(os.path.join(pred_frame_path, frame_name))
...@@ -91,3 +96,15 @@ class RealSRPredictor(BasePredictor): ...@@ -91,3 +96,15 @@ class RealSRPredictor(BasePredictor):
frames2video(frame_pattern_combined, vid_out_path, str(int(fps))) frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def run(self, input):
if self.is_video(input):
return self.run_video(input)
else:
pred_img = self.run_image(input)
if self.output:
base_name = os.path.basename(input)
pred_img.save(os.path.join(self.output, base_name + '.png'))
return pred_img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册