module.py 1.7 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
import os

from paddlehub import Module
from paddlehub.module.module import moduleinfo, serving

from UGATIT_92w.model import Model
from UGATIT_92w.processor import base64_to_cv2, cv2_to_base64, Processor


@moduleinfo(
    name="UGATIT_92w",  # 模型名称
    type="CV/style_transfer",  # 模型类型
    author="jm12138",  # 作者名称
    author_email="jm12138@qq.com",  # 作者邮箱
    summary="UGATIT_92w",  # 模型介绍
jm_12138's avatar
jm_12138 已提交
16
    version="1.0.1"  # 版本号
W
wuzewu 已提交
17 18 19
)
class UGATIT_92w(Module):
    # 初始化函数
jm_12138's avatar
jm_12138 已提交
20
    def __init__(self, name=None, use_gpu=False):
W
wuzewu 已提交
21 22 23 24
        # 设置模型路径
        self.model_path = os.path.join(self.directory, "UGATIT_92w")

        # 加载模型
jm_12138's avatar
jm_12138 已提交
25
        self.model = Model(modelpath=self.model_path, use_gpu=use_gpu, use_mkldnn=False, combined=False)
W
wuzewu 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57

    # 关键点检测函数
    def style_transfer(self, images=None, paths=None, batch_size=1, output_dir='output', visualization=False):
        # 加载数据处理器
        processor = Processor(images, paths, output_dir, batch_size)

        # 模型预测
        outputs = self.model.predict(processor.input_datas)

        # 结果后处理
        results = processor.postprocess(outputs, visualization)

        # 返回结果
        return results

    # Hub Serving
    @serving
    def serving_method(self, images, **kwargs):
        # 获取输入数据
        images_decode = [base64_to_cv2(image) for image in images]

        # 图片风格转换
        results = self.style_transfer(images_decode, **kwargs)

        # 对输出图片进行编码
        encodes = []
        for result in results:
            encode = cv2_to_base64(result)
            encodes.append(encode)

        # 返回结果
        return encodes