# utils.py
# coding:utf-8
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import contextlib
import hashlib
import importlib
import math
import os
import socket
import sys
import tempfile
import time
import traceback
import types
from typing import Generator, List, Optional, Tuple
from urllib.parse import urlparse

import cv2
import numpy as np
import packaging.version
import requests

import paddlehub.env as hubenv
import paddlehub.utils as utils


class Version(packaging.version.Version):
    '''Extended implementation of packaging.version.Version.

    Instances can be compared directly with version strings (the string is
    parsed into a Version first) and can evaluate simple conditions such as
    ``'>=1.2.0'`` through :meth:`match`.
    '''

    # Defining __eq__ in a subclass implicitly sets __hash__ to None, which would
    # make Version instances unhashable; restore the parent implementation.
    __hash__ = packaging.version.Version.__hash__

    # (prefix, comparison method name) pairs used by `match`. Order matters:
    # two-character operators must be tested before their one-character prefixes.
    _CONDITION_OPERATORS = (
        ('>=', '__ge__'),
        ('<=', '__le__'),
        ('==', '__eq__'),
        ('!=', '__ne__'),
        ('>', '__gt__'),
        ('<', '__lt__'),
        ('=', '__eq__'),
    )

    def match(self, condition: str) -> bool:
        '''
        Determine whether this version satisfies the given condition.

        Args:
            condition(str) : A comparison such as '>=1.2.0', '>1.2', '<=2.0', '<2.0',
                '==1.0', '!=1.0' or '=1.0'. A bare version string means equality.
                An empty (or whitespace-only / None) condition is always satisfied.

        Returns:
            bool: True if the given version condition is met, else False.

        Raises:
            packaging.version.InvalidVersion: if the version part of the condition
                cannot be parsed.

        Examples:
            .. code-block:: python

                Version('1.2.0').match('>=1.2.0a')
        '''
        # Tolerate surrounding whitespace, e.g. ' >=1.2.0 '.
        condition = condition.strip() if condition else ''
        if not condition:
            return True

        for prefix, method_name in self._CONDITION_OPERATORS:
            if condition.startswith(prefix):
                return getattr(self, method_name)(Version(condition[len(prefix):]))
        return self == Version(condition)

    @staticmethod
    def _coerce(other):
        '''Parse `other` into a Version when it is a string; return it unchanged otherwise.'''
        return Version(other) if isinstance(other, str) else other

    def __lt__(self, other):
        return super().__lt__(self._coerce(other))

    def __le__(self, other):
        return super().__le__(self._coerce(other))

    def __gt__(self, other):
        return super().__gt__(self._coerce(other))

    def __ge__(self, other):
        return super().__ge__(self._coerce(other))

    def __eq__(self, other):
        return super().__eq__(self._coerce(other))

    def __ne__(self, other):
        # Without this override, `Version('1.0') != '1.0'` falls back to identity
        # comparison and returns True even though `==` returns True.
        return super().__ne__(self._coerce(other))

class Timer(object):
    '''Calculate running speed and estimated time of arrival (ETA).

    Call :meth:`start` before reading :attr:`timing` or :attr:`eta`; both rely on
    the timestamps recorded there.
    '''

    def __init__(self, total_step: int):
        self.total_step = total_step
        self.last_start_step = 0
        self.current_step = 0
        self._is_running = True

    def start(self):
        '''Record the start time, which is also the first reference point for `timing`.'''
        now = time.time()
        self.last_time = now
        self.start_time = now

    def stop(self):
        '''Mark the timer as stopped and record the end time.'''
        self._is_running = False
        self.end_time = time.time()

    def count(self) -> int:
        '''Advance the step counter by one (never past `total_step`) and return it.'''
        if self.current_step < self.total_step:
            self.current_step += 1
        return self.current_step

    @property
    def timing(self) -> float:
        '''Steps per second since the previous read of `timing` (or since `start`).

        Reading this property resets its reference point. Returns 0.0 if no clock
        time has elapsed, instead of raising ZeroDivisionError.
        '''
        # Read the clock once so no time is lost between measuring and resetting.
        now = time.time()
        run_steps = self.current_step - self.last_start_step
        time_used = now - self.last_time
        self.last_start_step = self.current_step
        self.last_time = now
        if time_used <= 0:
            # Coarse clock resolution can make two consecutive reads coincide.
            return 0.0
        return run_steps / time_used

    @property
    def is_running(self) -> bool:
        return self._is_running

    @property
    def eta(self) -> str:
        '''Estimated remaining time formatted as hh:mm:ss.

        Returns '00:00:00' once the timer is stopped, and '--:--:--' before any
        step has been counted (no rate can be estimated yet).
        '''
        if not self.is_running:
            return '00:00:00'
        if self.current_step <= 0:
            return '--:--:--'
        elapsed = time.time() - self.start_time
        remaining_steps = self.total_step - self.current_step
        # Extrapolate the average time per completed step over the steps still to do.
        remaining_time = elapsed * remaining_steps / self.current_step
        return seconds_to_hms(remaining_time)


def seconds_to_hms(seconds: int) -> str:
    '''Format a duration given in seconds as an ``hh:mm:ss`` string.

    Each field is zero-padded to at least two digits; the hour field grows as
    needed (e.g. 360000 seconds -> '100:00:00'). Fractional seconds are dropped.
    '''
    hours = math.floor(seconds / 3600)
    leftover = seconds - hours * 3600
    minutes = math.floor(leftover / 60)
    secs = int(leftover - minutes * 60)
    return f'{hours:0>2}:{minutes:0>2}:{secs:0>2}'

def cv2_to_base64(image: np.ndarray) -> str:
    '''Encode an image as JPEG and return the encoded bytes as a base64 string.

    Args:
        image(np.ndarray) : image data accepted by ``cv2.imencode``
            (presumably an 8-bit BGR array as produced by cv2 -- TODO confirm).

    Returns:
        str: base64 text (decoded as UTF-8) of the JPEG-encoded image.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    '''
    success, data = cv2.imencode('.jpg', image)
    if not success:
        raise ValueError('Failed to encode the image as JPEG.')
    # tobytes() replaces the deprecated ndarray.tostring(); the bytes are identical.
    return base64.b64encode(data.tobytes()).decode('utf8')

def base64_to_cv2(b64str: str) -> np.ndarray:
    '''Convert a string in base64 format to cv2 data.

    Args:
        b64str(str) : base64 text of an encoded image (e.g. the output of
            cv2_to_base64). ``bytes`` holding the base64 text are accepted too.

    Returns:
        np.ndarray: the image decoded by ``cv2.imdecode`` with cv2.IMREAD_COLOR.
            The result of imdecode is returned unchanged; per the OpenCV docs it
            is None when the data cannot be decoded, so callers should check.
    '''
    if isinstance(b64str, str):
        b64str = b64str.encode('utf8')
    raw = base64.b64decode(b64str)
    # np.fromstring is deprecated for binary input; frombuffer is the drop-in replacement.
    encoded = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_COLOR)


@contextlib.contextmanager
def generate_tempfile(directory: str = None, **kwargs):
    '''Yield a named temporary file that is removed when the context exits.

    Args:
        directory(str, optional) : where to create the file; falls back to
            hubenv.TMP_HOME when empty or not given.
        **kwargs : forwarded to tempfile.NamedTemporaryFile.
    '''
    target_dir = directory if directory else hubenv.TMP_HOME
    with tempfile.NamedTemporaryFile(dir=target_dir, **kwargs) as tmp:
        yield tmp


@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
    '''Yield the path of a temporary directory that is removed when the context exits.

    Args:
        directory(str, optional) : parent directory; falls back to
            hubenv.TMP_HOME when empty or not given.
        **kwargs : forwarded to tempfile.TemporaryDirectory.
    '''
    parent = directory if directory else hubenv.TMP_HOME
    with tempfile.TemporaryDirectory(dir=parent, **kwargs) as tmp_dir:
        yield tmp_dir


def download(url: str, path: str = None) -> str:
    '''
    Download a file and return the local path it was saved to.

    This drains `download_with_progress`, discarding the progress reports.

    Args:
        url (str) : url to be downloaded
        path (str, optional) : path to store downloaded products, default is current work directory

    Returns:
        str: path of the downloaded file

    Examples:
        .. code-block:: python

            url = 'https://xxxxx.xx/xx.tar.gz'
            download(url, path='./output')
    '''
    progress = download_with_progress(url, path)
    for savename, _, _ in progress:
        pass
    return savename


def download_with_progress(url: str, path: str = None) -> Generator[Tuple[str, int, int], None, None]:
    '''
    Download a file and report the progress while it is being written.

    Yields ``(filename, download_size, total_size)`` tuples:
        - ``filename`` is the local path the file is saved to,
        - ``download_size`` is the number of bytes written so far,
        - ``total_size`` is the size announced by the server's Content-Length header,
          or 0 when the header is missing (e.g. chunked responses).
    At least one tuple is yielded, even for an empty response body.

    Args:
        url (str) : url to be downloaded
        path (str, optional) : path to store downloaded products, default is current work directory

    Raises:
        requests.HTTPError: if the server answers with an error status code.

    Examples:
        .. code-block:: python

            url = 'https://xxxxx.xx/xx.tar.gz'
            for filename, download_size, total_size in download_with_progress(url, path='./output'):
                print(filename, download_size, total_size)
    '''
    path = os.getcwd() if not path else path
    # exist_ok avoids a race if another process creates the directory concurrently.
    os.makedirs(path, exist_ok=True)

    # The local file name is the last component of the URL path (query string ignored).
    savename = os.path.join(path, urlparse(url).path.split('/')[-1])

    download_size = 0
    reported = False
    # The context manager releases the connection even if the consumer stops iterating early.
    with requests.get(url, stream=True) as res:
        # Fail loudly instead of saving an HTTP error page as the downloaded artifact.
        res.raise_for_status()
        total_size = int(res.headers.get('content-length', 0))
        with open(savename, 'wb') as _file:
            for data in res.iter_content(chunk_size=4096):
                _file.write(data)
                download_size += len(data)
                reported = True
                yield savename, download_size, total_size

    if not reported:
        # Empty body: still report the (empty) file so callers such as `download` get a name.
        yield savename, download_size, total_size


def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
    '''
    Load the specified python module.

    Any cached copy of the module in ``sys.modules`` is discarded first, and
    ``python_path`` is searched first while importing. The ``sys.path`` change is
    undone afterwards, also when the import fails.

    Args:
        python_path(str) : The directory where the python module is located
        py_module_name(str) : Module name to be loaded

    Returns:
        types.ModuleType: the freshly imported module.
    '''
    sys.path.insert(0, python_path)
    try:
        # Delete the cache module to avoid hazards. For example, when the user reinstalls a HubModule,
        # if the cache is not cleared, then what the user gets at this time is actually the HubModule
        # before uninstallation, this can cause some strange problems, e.g, fail to load model parameters.
        sys.modules.pop(py_module_name, None)
        return importlib.import_module(py_module_name)
    finally:
        # Remove by value rather than popping index 0: the imported module may itself
        # have inserted entries at the front of sys.path.
        try:
            sys.path.remove(python_path)
        except ValueError:
            # The imported code already removed the entry itself.
            pass


def get_platform_default_encoding() -> str:
    '''Return the fallback text encoding for this platform: 'gbk' on Windows, 'utf8' elsewhere.'''
    return 'gbk' if utils.platform.is_windows() else 'utf8'


def sys_stdin_encoding() -> str:
    '''Get the default encoding of the standard input stream.

    Falls back to ``sys.getdefaultencoding()`` and then to the platform default
    when the stream reports no encoding. ``sys.stdin`` can be None (for example
    under pythonw), so the attribute is read defensively.
    '''
    encoding = getattr(sys.stdin, 'encoding', None)
    if encoding is None:
        encoding = sys.getdefaultencoding()

    if encoding is None:
        encoding = get_platform_default_encoding()
    return encoding


def sys_stdout_encoding() -> str:
    '''Get the default encoding of the standard output stream.

    Falls back to ``sys.getdefaultencoding()`` and then to the platform default
    when the stream reports no encoding. ``sys.stdout`` can be None (for example
    under pythonw), so the attribute is read defensively.
    '''
    encoding = getattr(sys.stdout, 'encoding', None)
    if encoding is None:
        encoding = sys.getdefaultencoding()

    if encoding is None:
        encoding = get_platform_default_encoding()
    return encoding


def md5(text: str) -> str:
    '''Calculate the md5 value of the input text.

    Args:
        text(str) : text to hash; it is encoded as UTF-8 first. ``bytes`` /
            ``bytearray`` input is hashed as-is.

    Returns:
        str: the hexadecimal md5 digest.
    '''
    data = text if isinstance(text, (bytes, bytearray)) else text.encode('utf-8')
    return hashlib.md5(data).hexdigest()


def record(msg: str) -> str:
    '''Append a timestamped entry to the PaddleHub log file for today.

    The entry is a '=' banner, a 'Record at <time>' line, another banner, then
    ``str(msg)`` followed by three newlines.

    Args:
        msg(str) : content to record; converted with ``str()``.

    Returns:
        str: path of the log file that was written.
    '''
    logfile = get_record_file()
    banner = '=' * 50 + '\n'
    with open(logfile, 'a') as handle:
        handle.write(banner)
        handle.write('Record at {}\n'.format(time.strftime('%Y-%m-%d %H:%M:%S')))
        handle.write(banner)
        handle.write(str(msg) + '\n\n\n')
    return logfile


def record_exception(msg: str) -> str:
    '''Record the current exception information into the PaddleHub log file and warn the user.

    The traceback comes from ``traceback.format_exc()``, so this is meant to be
    called while an exception is being handled.

    Args:
        msg(str) : short description included in the warning message.

    Returns:
        str: path of the log file the traceback was written to.
    '''
    tb = traceback.format_exc()
    logfile = record(tb)
    utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, logfile))
    # The signature promises a str; return the log path so callers can point users to it.
    return logfile


def get_record_file() -> str:
    '''Return the path of today's log file: ``<hubenv.LOG_HOME>/<YYYYMMDD>.log``.'''
    log_name = time.strftime('%Y%m%d.log')
    return os.path.join(hubenv.LOG_HOME, log_name)


def is_port_occupied(ip: str, port: int) -> bool:
    '''
    Check whether a TCP port is occupied, i.e. whether a connection to it succeeds.

    Args:
        ip(str) : address to probe
        port(int) : port number; anything accepted by ``int()`` works, e.g. '8080'

    Returns:
        bool: True if a TCP connection to (ip, port) succeeds; False if it fails,
            or if the address or port is invalid.
    '''
    # The with-statement guarantees the probe socket is closed on every path.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.connect((ip, int(port)))
        except (OSError, ValueError, OverflowError):
            # Refused / unreachable / unresolvable, a non-numeric port (ValueError)
            # or a port outside 0-65535 (OverflowError).
            return False
        # The connection succeeded, so the port is in use. Shutting down is only a
        # courtesy to the peer; a failure here must not change the answer.
        with contextlib.suppress(OSError):
            s.shutdown(socket.SHUT_RDWR)
        return True


def mkdir(path: str):
    """The same as the shell command `mkdir -p`.

    Creates `path` and any missing parents. Does nothing if `path` already
    exists (even when it is a file, as before).
    """
    if not os.path.exists(path):
        # exist_ok closes the race where another process creates the directory
        # between the existence check and the makedirs call.
        os.makedirs(path, exist_ok=True)


def reseg_token_label(tokenizer, tokens: List[str], labels: List[str] = None) -> Tuple[List[str], Optional[List[str]]]:
    '''
    Convert segments and labels of sequence labeling samples into tokens
    based on the vocab of tokenizer.

    Each segment is re-tokenized with ``tokenizer``. Segments that produce no
    sub-tokens are dropped together with their label. When a segment splits into
    several sub-tokens, the first one keeps the original label. The remaining ones
    get "I-<type>" if the label is "B-<type>"; any other label is repeated as-is.

    Args:
        tokenizer: callable mapping a segment string to a list of sub-tokens.
        tokens(List[str]) : segments of one sample.
        labels(List[str], optional) : one label per segment, or None.

    Returns:
        Tuple[List[str], Optional[List[str]]]: the sub-tokens and their labels.
            The labels are None when ``labels`` is None.

    Raises:
        ValueError: if ``labels`` is given and its length differs from ``tokens``.
    '''
    if labels is None:
        ret_tokens = []
        for token in tokens:
            # Extending with an empty list is a no-op, so empty results need no special case.
            ret_tokens.extend(tokenizer(token))
        return ret_tokens, None

    # Compare against None (not truthiness) so that an empty label list is still validated.
    if len(tokens) != len(labels):
        raise ValueError(
            "The length of tokens must be same with labels")

    ret_tokens = []
    ret_labels = []
    for token, label in zip(tokens, labels):
        sub_tokens = tokenizer(token)
        if not sub_tokens:
            continue
        ret_tokens.extend(sub_tokens)
        ret_labels.append(label)
        if len(sub_tokens) > 1:
            # Continuation sub-tokens of a "B-" entity belong inside the same entity.
            inner_label = "I-" + label[2:] if label.startswith("B-") else label
            ret_labels.extend([inner_label] * (len(sub_tokens) - 1))
    # One label is appended per emitted sub-token, so the two lists always align.
    return ret_tokens, ret_labels