#  Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import numpy as np
import cv2

import paddle
from .base_predictor import BasePredictor
from ppgan.datasets.preprocess.transforms import ResizeToScale
import paddle.vision.transforms as T
from ppgan.models.generators import AnimeGenerator
from ppgan.utils.download import get_path_from_url


class AnimeGANPredictor(BasePredictor):
    """Stylize a photo with a pretrained ``AnimeGenerator`` and save the result.

    The predictor loads generator weights once at construction time and then
    converts one image file per ``run`` call, writing the stylized picture to
    ``<output_path>/anime.png``.
    """

    def __init__(self,
                 output_path='output',
                 weight_path=None,
                 use_adjust_brightness=True):
        """Build the generator and the preprocessing pipeline.

        Args:
            output_path: Directory in which ``run`` stores ``anime.png``.
            weight_path: Path to a checkpoint file containing a ``'netG'``
                entry. When ``None``, the default ``animeganv2_hayao``
                weights are fetched through ``get_path_from_url``.
            use_adjust_brightness: If true, ``run`` rescales the generated
                image so its mean brightness matches the input's.
        """
        self.output_path = output_path
        self.input_size = (256, 256)
        self.use_adjust_brightness = use_adjust_brightness
        if weight_path is None:
            weight_url = 'https://paddlegan.bj.bcebos.com/models/animeganv2_hayao.pdparams'
            weight_path = get_path_from_url(weight_url)
        self.weight_path = weight_path
        self.generator = self.load_checkpoints()
        # Normalizing with mean = std = 127.5 maps pixel values from [0, 255]
        # to [-1, 1]; ``run`` undoes this with ``* 0.5 + 0.5`` and ``* 255``.
        # NOTE(review): T.Transpose() presumably converts HWC to CHW and
        # ResizeToScale presumably snaps the size to a multiple of 32 --
        # confirm against their implementations.
        self.transform = T.Compose([
            ResizeToScale((256, 256), 32),
            T.Transpose(),
            T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
        ])

    def load_checkpoints(self):
        """Create an ``AnimeGenerator`` initialised from ``self.weight_path``.

        Returns:
            The generator, with the checkpoint's ``'netG'`` state loaded and
            switched to evaluation mode.
        """
        generator = AnimeGenerator()
        checkpoint = paddle.load(self.weight_path)
        generator.set_state_dict(checkpoint['netG'])
        generator.eval()
        return generator

    @staticmethod
    def calc_avg_brightness(img):
        """Compute the mean brightness and per-channel means of an image.

        Args:
            img: Array whose last axis holds at least three channels; channels
                0, 1 and 2 are treated as R, G and B respectively.

        Returns:
            A tuple ``(brightness, B, G, R)`` where ``brightness`` is the
            weighted sum ``0.299 * R + 0.587 * G + 0.114 * B`` of the channel
            means. Note the channel means are returned in B, G, R order.
        """
        R = img[..., 0].mean()
        G = img[..., 1].mean()
        B = img[..., 2].mean()

        brightness = 0.299 * R + 0.587 * G + 0.114 * B
        return brightness, B, G, R

    @staticmethod
    def adjust_brightness(dst, src):
        """Scale ``dst`` so that its mean brightness matches that of ``src``.

        Args:
            dst: Image array to adjust (the generated picture).
            src: Reference image array (the original picture).

        Returns:
            A ``uint8`` array with the shape of ``dst``: ``dst`` multiplied by
            the brightness ratio ``src / dst`` and clipped to [0, 255]. When
            the mean brightness of ``dst`` is not positive (e.g. an all-black
            image) no scaling is applied, since the ratio would be undefined.
        """
        src_brightness = AnimeGANPredictor.calc_avg_brightness(src)[0]
        dst_brightness = AnimeGANPredictor.calc_avg_brightness(dst)[0]
        if dst_brightness > 0:
            dst = dst * (src_brightness / dst_brightness)
        # Clip before the cast so out-of-range values saturate instead of
        # wrapping around.
        return np.uint8(np.clip(dst, 0, 255))

    def run(self, image):
        """Stylize the image file at ``image`` and save it as ``anime.png``.

        Args:
            image: Path of the input image file.

        Returns:
            The input picture as an RGB array (not the generated one).

        Raises:
            ValueError: If the file cannot be read or decoded by OpenCV.
        """
        bgr = cv2.imread(image, flags=cv2.IMREAD_COLOR)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError('Cannot read image file: {}'.format(image))
        src_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        batch = paddle.to_tensor(self.transform(src_rgb)[None, ...])
        # Inference only: do not record an autograd graph for the forward pass.
        with paddle.no_grad():
            generated = self.generator(batch)
        # Undo the [-1, 1] normalization and go back to the [0, 255] range.
        anime = (generated * 0.5 + 0.5)[0].numpy() * 255
        anime = anime.transpose((1, 2, 0))
        if anime.shape[:2] != src_rgb.shape[:2]:
            # to original size
            anime = T.resize(anime, src_rgb.shape[:2])
        if self.use_adjust_brightness:
            anime = self.adjust_brightness(anime, src_rgb)
        else:
            anime = anime.astype('uint8')

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.output_path, exist_ok=True)
        save_path = os.path.join(self.output_path, 'anime.png')
        cv2.imwrite(save_path, cv2.cvtColor(anime, cv2.COLOR_RGB2BGR))
        # NOTE(review): this returns the source image, not ``anime``; kept
        # as-is to preserve the existing return contract -- confirm whether
        # callers actually expect the stylized result here.
        return src_rgb