module_v1.py 5.1 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
# coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
from typing import Tuple, List

import paddle

from paddlehub.compat import paddle_utils
from paddlehub.compat.module import module_v1_utils
from paddlehub.utils import utils, log


class ModuleV1(object):
    '''
    Wrapper around a module stored in the v1 on-disk layout.

    When constructed with a directory, the instance reads ``module_desc.pb`` from it and then loads, in order:
    the inference program under ``model``, the parameters listed in the descriptor, the processor found under
    ``python``, the file list under ``assets``, the descriptor's extra info and one callable per signature.
    '''
    def __init__(self, name: str = None, directory: str = None, version: str = None):
        '''
        Args:
            name(str): Not read by this constructor.
            directory(str): Root directory of the module. When it is empty or None nothing is loaded and the
                instance is left without ``directory``, ``desc``, ``program``, ``processor`` or ``assets``.
            version(str): Not read by this constructor.
        '''
        if not directory:
            return

        self.directory = directory
        desc_file = os.path.join(directory, 'module_desc.pb')
        # NOTE(review): convert_module_desc presumably parses the protobuf descriptor - confirm in module_v1_utils.
        self.desc = module_v1_utils.convert_module_desc(desc_file)

        # Order matters: _load_parameters reads self.program, which _load_model creates.
        self._load_model()
        self._load_parameters()
        self._load_processor()
        self._load_assets()
        self._load_extra_info()
        self._load_signatures()

    def _load_processor(self):
        '''Load the processor named by ``desc.processor_info`` from ``<directory>/python`` into ``self.processor``.'''
        python_path = os.path.join(self.directory, 'python')
        processor_name = self.desc.processor_info
        self.processor = utils.load_py_module(python_path, processor_name)

    def _load_assets(self):
        '''
        Store in ``self.assets`` the joined path of every entry directly under ``<directory>/assets``.

        Entries are kept in the order returned by ``os.listdir``; a missing ``assets`` directory raises the
        ``OSError`` that ``os.listdir`` produces.
        '''
        assets_path = os.path.join(self.directory, 'assets')
        self.assets = [os.path.join(assets_path, entry) for entry in os.listdir(assets_path)]

    def _load_parameters(self):
        '''
        Re-create, as parameters of the global block, the variables listed in ``desc.param_attrs``.

        Each described name is prefixed with ``desc.name_prefix``; names absent from the global block are skipped.
        The new parameter copies the properties of the existing variable and adds the described attributes.
        '''
        global_block = self.program.global_block()
        for param, attrs in self.desc.param_attrs.items():
            name = self.desc.name_prefix + param
            if name not in global_block.vars:
                continue

            var = global_block.vars[name]
            global_block.create_parameter(name=name,
                                          shape=var.shape,
                                          dtype=var.dtype,
                                          type=var.type,
                                          lod_level=var.lod_level,
                                          error_clip=var.error_clip,
                                          stop_gradient=var.stop_gradient,
                                          is_data=var.is_data,
                                          **attrs)

    def _load_extra_info(self):
        '''
        Expose every ``desc.extra_info`` item as an instance attribute named ``get_<key>``.

        The attribute holds the value itself, so it is read as ``module.get_<key>`` (it is not wrapped in a
        function here).
        '''
        for key, value in self.desc.extra_info.items():
            self.__dict__['get_{}'.format(key)] = value

    def _load_signatures(self):
        '''
        Expose every signature in ``desc.signatures`` as an instance attribute that forwards to ``__call__``.

        After this call ``module.<signature>(data, ...)`` is equivalent to ``module(<signature>, data, ...)``.
        '''
        for signature in self.desc.signatures:
            # Bind the signature positionally. ``signature`` is the first positional parameter of __call__, so
            # binding it by keyword made ``module.<signature>(data)`` raise
            # "TypeError: got multiple values for argument 'signature'".
            self.__dict__[signature] = functools.partial(self.__call__, signature)

    def _load_model(self):
        '''Load the inference program stored under ``<directory>/model`` into ``self.program``.'''
        model_path = os.path.join(self.directory, 'model')
        exe = paddle.static.Executor(paddle.CPUPlace())
        # Only the program is kept; the two other values returned by load_inference_model are discarded.
        self.program, _, _ = paddle.io.load_inference_model(model_path, executor=exe)

        # Clear the callstack since it may leak the privacy of the creator.
        for block in self.program.blocks:
            for op in block.ops:
                if 'op_callstack' not in op.all_attrs():
                    continue
                op._set_attr('op_callstack', [''])

    # The return annotation is written as a string so that it is not evaluated when the class is defined.
    def context(self, for_test: bool = False, trainable: bool = True) -> 'Tuple[dict, dict, paddle.static.Program]':
        '''
        Build a program derived from the loaded one, together with its input and output variables.

        Args:
            for_test(bool): Forwarded to ``Program.clone``.
            trainable(bool): Value assigned to the ``trainable`` attribute of every parameter of the clone.

        Returns:
            tuple: ``(feed_dict, fetch_dict, program)``. The two dicts map the alias of every feed / fetch variable
            declared by the signatures to the matching variable of the cloned program. If several signatures
            declare the same alias, the last one wins.
        '''
        program = self.program.clone(for_test=for_test)
        paddle_utils.remove_feed_fetch_op(program)

        # generate feed vars and fetch vars from signatures
        block_vars = program.global_block().vars
        feed_dict = {}
        fetch_dict = {}
        for info in self.desc.signatures.values():
            for feed_var in info.feed_vars:
                feed_dict[feed_var.alias] = block_vars[feed_var.name]

            for fetch_var in info.fetch_vars:
                fetch_dict[fetch_var.alias] = block_vars[fetch_var.name]

        # record num parameters loaded by PaddleHub
        num_param_loaded = 0
        for param in program.all_parameters():
            num_param_loaded += 1
            param.trainable = trainable

        log.logger.info('{} pretrained paramaters loaded by PaddleHub'.format(num_param_loaded))

        return feed_dict, fetch_dict, program

    def __call__(self, signature, data, use_gpu: bool = False, batch_size: int = 1, **kwargs):
        '''
        Entry point targeted by the per-signature attributes created in ``_load_signatures``.

        The body is a placeholder in this file: the arguments are accepted and ``None`` is returned.

        Args:
            signature: Name of the signature to run.
            data: Input data for the signature.
            use_gpu(bool): Currently unused.
            batch_size(int): Currently unused.
            **kwargs: Currently unused.
        '''
        ...

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''Return the Python requirements of the module; always an empty list here.'''
        return []

    @classmethod
    def load(cls, desc_file):
        '''
        Copy the module information found in ``desc_file`` onto the class and return the class.

        Sets ``author``, ``author_email``, ``summary``, ``type``, ``name`` and ``version`` (wrapped in
        ``utils.Version``) as class attributes.

        NOTE(review): the attributes are written on ``cls`` itself and ``cls`` is returned, so loading a second
        descriptor through the same class overwrites the values set by the first one.

        Args:
            desc_file: Path of the module descriptor file handed to ``convert_module_desc``.

        Returns:
            The class the method was called on, with the attributes above set.
        '''
        desc = module_v1_utils.convert_module_desc(desc_file)
        module_info = desc.module_info

        cls.author = module_info.author
        cls.author_email = module_info.author_email
        cls.summary = module_info.summary
        cls.type = module_info.type
        cls.name = module_info.name
        cls.version = utils.Version(module_info.version)
        return cls