predict.py 4.8 KB
Newer Older
L
LielinJiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
import os
import sys

cur_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(cur_path)

import cv2
import glob
import argparse
import numpy as np
import paddle
import pickle

from PIL import Image
from tqdm import tqdm
from paddle import fluid
L
LielinJiang 已提交
17
from paddle.utils.download import get_path_from_url
L
LielinJiang 已提交
18 19 20
from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators.deoldify import build_model

L
LielinJiang 已提交
21
# Command-line interface for the DeOldify video colorization script.
parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir')
# The network input is a square whose side is render_factor * 16 pixels.
parser.add_argument('--render_factor',
                    type=int,
                    default=32,
                    help='model inputsize=render_factor*16')
# Left as None, the predictor downloads the pretrained weights itself.
parser.add_argument('--weight_path',
                    type=str,
                    default=None,
                    help='Path to the model weight file (pretrained DeOldify '
                    'weights are downloaded if omitted)')
L
LielinJiang 已提交
32

L
LielinJiang 已提交
33
# Pretrained weights fetched by DeOldifyPredictor when no weight path is given.
DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
L
LielinJiang 已提交
34

L
LielinJiang 已提交
35

L
LielinJiang 已提交
36
class DeOldifyPredictor():
    """Colorize a video frame by frame with the DeOldify generator.

    The predictor splits the input video into frames, runs the model on
    each frame, writes the predicted frames to disk and reassembles them
    into an ``.mp4`` file.
    """

    def __init__(self,
                 input,
                 output,
                 batch_size=1,
                 weight_path=None,
                 render_factor=32):
        """Build the model and load its weights.

        Args:
            input: Path of the video to colorize.
            output: Base output directory; results are written under
                ``<output>/DeOldify``.
            batch_size: Accepted for interface compatibility; not used by
                this class.
            weight_path: Path to the weight file. When ``None`` the
                pretrained weights at ``DEOLDIFY_WEIGHT_URL`` are fetched
                next to this script.
            render_factor: Multiplier controlling the model input size
                (see ``norm``).
        """
        self.input = input
        self.output = os.path.join(output, 'DeOldify')
        self.render_factor = render_factor
        self.model = build_model()
        if weight_path is None:
            weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL, cur_path)

        # NOTE(review): unpacking assumes paddle.load returns a 2-tuple
        # (parameters, second value unused here) - confirm against the
        # installed Paddle version.
        state_dict, _ = paddle.load(weight_path)
        self.model.load_dict(state_dict)
        self.model.eval()

    def norm(self, img, render_factor=32, render_base=16):
        """Resize a PIL image and convert it to a normalized CHW array.

        Args:
            img: PIL image with three channels.
            render_factor: Size multiplier.
            render_base: Base size; the image is resized to a square of
                side ``render_factor * render_base``.

        Returns:
            ``float32`` array of shape ``(3, size, size)``, scaled to
            ``[0, 1]`` and then normalized with per-channel mean and std.
        """
        target_size = render_factor * render_base
        img = img.resize((target_size, target_size), resample=Image.BILINEAR)

        # HWC uint8 -> CHW float32 in [0, 1].
        img = np.array(img).transpose([2, 0, 1]).astype('float32') / 255.0

        img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
        img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

        img -= img_mean
        img /= img_std
        return img.astype('float32')

    def denorm(self, img):
        """Invert ``norm``'s normalization and return an HWC uint8 image.

        Args:
            img: Floating-point array of shape ``(3, H, W)``. It is not
                modified.

        Returns:
            ``uint8`` array of shape ``(H, W, 3)`` with values clipped to
            ``[0, 255]``.
        """
        img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
        img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

        # Copy first so the in-place arithmetic below does not overwrite
        # the caller's array; the dtype (and thus the numerics) is kept.
        img = img.copy()
        img *= img_std
        img += img_mean
        img = img.transpose((1, 2, 0))

        return (img * 255).clip(0, 255).astype('uint8')

    def post_process(self, raw_color, orig):
        """Combine the original luminance with the predicted chrominance.

        Args:
            raw_color: Colorized image, same size as ``orig``.
            orig: Original image.

        Returns:
            PIL image whose Y channel comes from ``orig`` and whose U/V
            channels come from ``raw_color``.
        """
        color_np = np.asarray(raw_color)
        orig_np = np.asarray(orig)
        # NOTE(review): run_single passes PIL RGB images while the BGR
        # conversion codes are used here; the same convention is applied
        # in both directions - confirm this is intended.
        color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
        orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
        hires = np.copy(orig_yuv)
        # Keep channel 0 (Y) of the original, take channels 1-2 (U, V)
        # from the model output.
        hires[:, :, 1:3] = color_yuv[:, :, 1:3]
        final = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
        final = Image.fromarray(final)
        return final

    def run_single(self, img_path):
        """Colorize one image file.

        Args:
            img_path: Path of the frame to colorize.

        Returns:
            Colorized PIL image with the same size as the input frame.
        """
        # Drop any existing colour, then expand back to three channels.
        ori_img = Image.open(img_path).convert('LA').convert('RGB')
        img = self.norm(ori_img, self.render_factor)
        # Add a leading batch axis of size 1.
        x = paddle.to_tensor(img[np.newaxis, ...])
        out = self.model(x)

        pred_img = self.denorm(out.numpy()[0])
        pred_img = Image.fromarray(pred_img)
        pred_img = pred_img.resize(ori_img.size, resample=Image.BILINEAR)
        pred_img = self.post_process(pred_img, ori_img)
        return pred_img

    def run(self):
        """Colorize the whole input video.

        Returns:
            Tuple ``(frame_pattern, video_path)``: the ``%08d.png`` pattern
            of the predicted frames and the path of the assembled video.
        """
        vid = self.input
        base_name = os.path.basename(vid).split('.')[0]
        output_path = os.path.join(self.output, base_name)
        pred_frame_path = os.path.join(output_path, 'frames_pred')

        # pred_frame_path lives under output_path, so this creates both;
        # exist_ok avoids the check-then-create race.
        os.makedirs(pred_frame_path, exist_ok=True)

        # The capture is only needed to read the frame rate; release it so
        # the underlying video handle is not leaked.
        cap = cv2.VideoCapture(vid)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()

        # presumably extracts the frames as PNG files and returns their
        # directory - verify against ppgan.utils.video.video2frames
        out_path = video2frames(vid, output_path)

        frames = sorted(glob.glob(os.path.join(out_path, '*.png')))

        for frame in tqdm(frames):
            pred_img = self.run_single(frame)

            frame_name = os.path.basename(frame)
            pred_img.save(os.path.join(pred_frame_path, frame_name))

        # NOTE(review): assumes extracted frames are named with 8-digit
        # zero-padded indices - confirm against video2frames.
        frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')

        vid_out_path = os.path.join(output_path,
                                    '{}_deoldify_out.mp4'.format(base_name))
        # int() truncates fractional frame rates (e.g. 29.97 becomes 29).
        frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))

        return frame_pattern_combined, vid_out_path
L
LielinJiang 已提交
132 133


L
LielinJiang 已提交
134
if __name__ == '__main__':
    # Script entry point: colorize the video named on the command line.
    paddle.disable_static()

    cli_args = parser.parse_args()
    colorizer = DeOldifyPredictor(input=cli_args.input,
                                  output=cli_args.output,
                                  weight_path=cli_args.weight_path,
                                  render_factor=cli_args.render_factor)

    # run() also returns the predicted-frame pattern, which is not needed here.
    _, result_video = colorizer.run()
    print('output video path:', result_video)