checker.py 7.4 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15

W
wuzewu 已提交
16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
W
wuzewu 已提交
19

W
wuzewu 已提交
20
import os
W
wuzewu 已提交
21

W
wuzewu 已提交
22 23
import paddle

W
wuzewu 已提交
24
from paddlehub.common.logger import logger
W
wuzewu 已提交
25
from paddlehub.common.utils import version_compare
W
wuzewu 已提交
26 27 28
from paddlehub.module import check_info_pb2
from paddlehub.version import hub_version, module_proto_version

W
wuzewu 已提交
29 30 31 32 33
# Name of the serialized CheckInfo manifest file that ModuleChecker reads from
# and writes to inside a module directory.
CHECK_INFO_PB_FILENAME = "check_info.pb"
# Path separator used in the manifest's file names; ModuleChecker converts
# between this and os.sep when comparing against the local file system.
FILE_SEP = "/"


W
wuzewu 已提交
34
class ModuleChecker(object):
    """Generate and verify the ``check_info.pb`` manifest of a module directory.

    The manifest is a serialized ``check_info_pb2.CheckInfo`` message holding
    the PaddlePaddle / PaddleHub / module-proto versions the module was built
    with, plus one entry for every file and sub-directory of the module.

    Typical use is either ``generate_check_info()`` (when packaging a module)
    or ``check()`` (when loading one).
    """

    def __init__(self, directory):
        """Bind the checker to a module directory.

        Args:
            directory: Path of the module directory. The manifest is expected
                at ``<directory>/check_info.pb``.
        """
        self._directory = directory
        self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)
        # Start from an empty message so the read-only properties below return
        # protobuf defaults instead of raising AttributeError when they are
        # accessed before ``check()`` has loaded the real manifest.
        self.check_info = check_info_pb2.CheckInfo()

    def generate_check_info(self):
        """Scan the module directory and write its manifest to ``pb_path``.

        The manifest records the current paddle, hub and module-proto versions,
        the module code version ``"v2"``, and one entry per file or directory
        found under ``directory`` (visited breadth-first). Every entry is
        marked as required (``is_need = True``) and its name is stored relative
        to ``directory`` using ``FILE_SEP`` as the separator.
        """
        check_info = check_info_pb2.CheckInfo()
        check_info.paddle_version = paddle.__version__
        check_info.hub_version = hub_version
        check_info.module_proto_version = module_proto_version
        check_info.module_code_version = "v2"
        file_infos = check_info.file_infos

        # Queue of paths (relative to the module directory) still to record.
        pending = list(os.listdir(self.directory))
        while pending:
            rel_path = pending.pop(0)
            abs_path = os.path.join(self.directory, rel_path)

            file_info = file_infos.add()
            # str.replace returns a new string, so the normalised name has to
            # be stored explicitly; ``rel_path`` itself keeps os.sep because it
            # is still needed to build file-system paths for the children.
            file_info.file_name = rel_path.replace(os.sep, FILE_SEP)
            file_info.is_need = True

            if os.path.isdir(abs_path):
                file_info.type = check_info_pb2.DIR
                pending.extend(
                    os.path.join(rel_path, child)
                    for child in os.listdir(abs_path))
            else:
                file_info.type = check_info_pb2.FILE

        with open(self.pb_path, "wb") as fout:
            fout.write(check_info.SerializeToString())

    @property
    def module_code_version(self):
        """Module code version stored in the loaded manifest."""
        return self.check_info.module_code_version

    @property
    def module_proto_version(self):
        """Module proto version stored in the loaded manifest."""
        return self.check_info.module_proto_version

    @property
    def hub_version(self):
        """PaddleHub version stored in the loaded manifest."""
        return self.check_info.hub_version

    @property
    def paddle_version(self):
        """PaddlePaddle version stored in the loaded manifest."""
        return self.check_info.paddle_version

    @property
    def file_infos(self):
        """File and directory entries stored in the loaded manifest."""
        return self.check_info.file_infos

    @property
    def directory(self):
        """Module directory this checker was created for."""
        return self._directory

    @property
    def pb_path(self):
        """Path of the manifest file inside the module directory."""
        return self._pb_path

    def check(self):
        """Load the manifest and run every validation step.

        All steps are executed even after an earlier one has failed, so that
        every problem is reported through ``logger.warning``.

        Returns:
            bool: True only if the manifest could be read completely, contains
            all version fields and file entries, and both ``check_module()``
            and ``check_compatibility()`` succeed.
        """
        result = True

        self.check_info = check_info_pb2.CheckInfo()
        if not os.path.isfile(self.pb_path):
            logger.warning(
                "This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
            result = False
        else:
            try:
                with open(self.pb_path, "rb") as fi:
                    pb_string = fi.read()
                # Keep the parser's return value in its own variable: it is
                # None or a byte count depending on the protobuf backend and
                # must not leak into the boolean ``result``.
                parsed_size = self.check_info.ParseFromString(pb_string)
                if not pb_string or (parsed_size is not None
                                     and parsed_size != len(pb_string)):
                    logger.warning(
                        "File [%s] is incomplete" % CHECK_INFO_PB_FILENAME)
                    result = False
            except Exception as e:
                # Best effort: an unreadable or corrupt manifest fails the
                # check instead of crashing the caller, but say why.
                logger.warning("Unable to load file [%s]: %s" %
                               (CHECK_INFO_PB_FILENAME, e))
                result = False

        if not self.check_info.paddle_version:
            logger.warning("Unable to read paddle version from [%s]" %
                           CHECK_INFO_PB_FILENAME)
            result = False

        if not self.check_info.hub_version:
            logger.warning(
                "Unable to read hub version from [%s]" % CHECK_INFO_PB_FILENAME)
            result = False

        if not self.check_info.module_proto_version:
            logger.warning("Unable to read module pb version from [%s]" %
                           CHECK_INFO_PB_FILENAME)
            result = False

        if not self.check_info.file_infos:
            logger.warning(
                "Unable to read file info from [%s]" % CHECK_INFO_PB_FILENAME)
            result = False

        if not self.check_module():
            result = False

        if not self.check_compatibility():
            result = False

        return result

    def check_compatibility(self):
        """Compare the manifest's versions with the local environment.

        The three version checks are all evaluated (no short-circuit) so each
        one gets the chance to log its own warning.

        Returns:
            bool: True if the proto, hub and paddle version checks all pass.
        """
        proto_ok = self._check_module_proto_version()
        hub_ok = self._check_hub_version()
        paddle_ok = self._check_paddle_version()
        return proto_ok and hub_ok and paddle_ok

    def check_module(self):
        """Return True if the module's files and dependencies are valid.

        The dependency check is skipped when the integrity check fails.
        """
        return self._check_module_integrity() and self._check_dependency()

    def _check_dependency(self):
        """Placeholder for a dependency check; currently always True."""
        return True

    def _check_module_proto_version(self):
        """Return True if the manifest's proto version equals the local one."""
        if self.module_proto_version != module_proto_version:
            logger.warning(
                "Module description file version cannot be aligned with PaddleHub version"
            )
            return False
        return True

    def _check_hub_version(self):
        """Return False (and warn) if the module's PaddleHub version is rejected.

        NOTE(review): relies on ``version_compare(a, b)`` being truthy when
        the module's version ``a`` is incompatible with (presumably newer
        than) the local version ``b`` - confirm against
        paddlehub.common.utils.
        """
        if version_compare(self.hub_version, hub_version):
            logger.warning(
                "This Module is generated by the PaddleHub with version %s, and the local PaddleHub version is %s, which may cause serious incompatible bug. Please upgrade PaddleHub to the latest version."
                % (self.hub_version, hub_version))
            return False
        return True

    def _check_paddle_version(self):
        """Return False (and warn) if the module's PaddlePaddle version is rejected.

        NOTE(review): same ``version_compare`` assumption as
        ``_check_hub_version`` - confirm against paddlehub.common.utils.
        """
        if version_compare(self.paddle_version, paddle.__version__):
            logger.warning(
                "This Module is generated by the PaddlePaddle with version %s, and the local PaddlePaddle version is %s, which may cause serious incompatible bug. Please upgrade PaddlePaddle to the latest version."
                % (self.paddle_version, paddle.__version__))
            return False
        return True

    def _check_module_integrity(self):
        """Verify every manifest entry against the module directory.

        A required entry (``is_need``) must exist, and an existing entry must
        have the recorded type (regular file vs. directory). All entries are
        inspected so that every mismatch is logged.

        Returns:
            bool: True if no entry is missing or of the wrong type.
        """
        result = True
        for file_info in self.file_infos:
            file_type = file_info.type
            # Manifest names use FILE_SEP; convert to the local separator.
            local_name = file_info.file_name.replace(FILE_SEP, os.sep)
            file_path = os.path.join(self.directory, local_name)

            if not os.path.exists(file_path):
                if file_info.is_need:
                    logger.warning(
                        "Module integrity check failed! Missing file [%s]" %
                        file_path)
                    result = False
                continue

            if file_type == check_info_pb2.FILE:
                type_matches = os.path.isfile(file_path)
            elif file_type == check_info_pb2.DIR:
                type_matches = os.path.isdir(file_path)
            else:
                # Unknown entry types are not validated.
                type_matches = True

            if not type_matches:
                logger.warning("File type check error %s" % file_path)
                result = False
        return result