utils.py 4.2 KB
Newer Older
K
KP 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
K
KP 已提交
15 16
import tarfile
import zipfile
K
KP 已提交
17 18 19 20 21
from typing import Any
from typing import Dict

from paddle.framework import load

小湉湉's avatar
小湉湉 已提交
22
from . import download
K
KP 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
from .entry import commands

# Names exported by a star-import of this module; the underscore-prefixed
# helpers and the module-level home constants are intentionally not listed.
__all__ = [
    'cli_register',
    'get_command',
    'download_and_decompress',
    'load_state_dict_from_url',
]


def cli_register(name: str, description: str='') -> Any:
    """
    Build a decorator that registers an object under a dotted command name.

    Every '.'-separated part of `name` is used as a key to descend one level
    into `commands`. The decorated object is stored under the '_entry' key of
    the final level and, when `description` is non-empty, the description is
    stored under '_description'. The decorated object is returned unchanged.

    NOTE(review): levels are reached with plain indexing, so whether missing
    levels are created on the fly depends on the type of `commands` (defined
    in `.entry`) -- confirm there.
    """

    def _register(command):
        node = commands
        for key in name.split('.'):
            node = node[key]

        node['_entry'] = command
        if description:
            node['_description'] = description
        return command

    return _register


def get_command(name: str) -> Any:
    """
    Return the object stored under '_entry' for a dotted command name.

    The name is split on '.' and each part is used as a key to descend one
    level into `commands`, mirroring how `cli_register` stores entries.
    """
    node = commands
    for key in name.split('.'):
        node = node[key]
    return node['_entry']


K
KP 已提交
57 58 59 60 61 62 63 64 65 66
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
    file_dir = os.path.dirname(filepath)
    if tarfile.is_tarfile(filepath):
        files = tarfile.open(filepath, "r:*")
        file_list = files.getnames()
    elif zipfile.is_zipfile(filepath):
        files = zipfile.ZipFile(filepath, 'r')
        file_list = files.namelist()
    else:
        return file_dir
K
KP 已提交
67 68

    if download._is_a_single_file(file_list):
K
KP 已提交
69 70
        rootpath = file_list[0]
        uncompressed_path = os.path.join(file_dir, rootpath)
K
KP 已提交
71 72
    elif download._is_a_single_dir(file_list):
        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
K
KP 已提交
73 74 75 76 77 78 79 80 81
        uncompressed_path = os.path.join(file_dir, rootpath)
    else:
        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
        uncompressed_path = os.path.join(file_dir, rootpath)

    files.close()
    return uncompressed_path


K
KP 已提交
82
def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
    """
    Download an archive and decompress it into a specific path.

    If a file named after the URL's basename already exists under `path` and
    passes the md5 check, it is reused and only decompressed when the
    expected output directory is missing; otherwise the download helper is
    asked to fetch it.

    Args:
        archive: Mapping that must contain the keys 'url' and 'md5'.
        path: Directory that holds the downloaded file; created if missing.

    Returns:
        The expected decompressed path for a reused file, or the value
        returned by `download.get_path_from_url` for a fresh download.

    Raises:
        AssertionError: If 'url' or 'md5' is missing from `archive`.
    """
    # Validate before touching the filesystem so a bad argument does not
    # leave an empty directory behind. Raised explicitly (not via `assert`)
    # so the check survives `python -O`; the exception type is unchanged.
    if 'url' not in archive or 'md5' not in archive:
        raise AssertionError(
            'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.
            format(list(archive.keys())))

    # exist_ok avoids a failure when another process creates the directory
    # between a separate existence check and the creation call.
    os.makedirs(path, exist_ok=True)

    filepath = os.path.join(path, os.path.basename(archive['url']))
    if os.path.isfile(filepath) and download._md5check(filepath,
                                                       archive['md5']):
        uncompress_path = _get_uncompress_path(filepath)
        if not os.path.isdir(uncompress_path):
            download._decompress(filepath)
    else:
        uncompress_path = download.get_path_from_url(archive['url'], path,
                                                     archive['md5'])

    return uncompress_path
K
KP 已提交
103 104


K
KP 已提交
105
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> Any:
    """
    Download a file from a URL and load it with `paddle.framework.load`.

    Args:
        url: Location of the file to download.
        path: Directory that holds the downloaded file; created if missing.
        md5: Optional checksum forwarded to the download helper.

    Returns:
        Whatever `load` returns for the downloaded file (the function name
        suggests a state dict; the previous `os.PathLike` annotation did not
        match the returned value).
    """
    # exist_ok avoids a failure when another process creates the directory
    # between a separate existence check and the creation call.
    os.makedirs(path, exist_ok=True)

    download.get_path_from_url(url, path, md5)
    # NOTE(review): assumes the helper stores the file under `path` using the
    # URL's basename -- confirm against download.get_path_from_url.
    return load(os.path.join(path, os.path.basename(url)))
K
KP 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143


def _get_user_home():
    return os.path.expanduser('~')


def _get_paddlespcceh_home():
    if 'PPSPEECH_HOME' in os.environ:
        home_path = os.environ['PPSPEECH_HOME']
        if os.path.exists(home_path):
            if os.path.isdir(home_path):
                return home_path
            else:
                raise RuntimeError(
                    'The environment variable PPSPEECH_HOME {} is not a directory.'.
                    format(home_path))
        else:
            return home_path
    return os.path.join(_get_user_home(), '.paddlespeech')


def _get_sub_home(directory):
    """
    Return `directory` under the package home, creating it when absent.

    Args:
        directory: Name of the sub-directory below the package home.

    Returns:
        The joined path. A path that already exists is returned untouched.
    """
    home = os.path.join(_get_paddlespcceh_home(), directory)
    if not os.path.exists(home):
        # This runs at import time, so several processes may get here at
        # once; exist_ok keeps the loser of that race from raising
        # FileExistsError after the existence check above.
        os.makedirs(home, exist_ok=True)
    return home


# Package home, resolved once when this module is imported (the PPSPEECH_HOME
# environment variable if set, otherwise '.paddlespeech' under the user home).
PPSPEECH_HOME = _get_paddlespcceh_home()
# 'models' directory under the package home. _get_sub_home creates it when it
# does not exist, so importing this module can create a directory on disk.
MODEL_HOME = _get_sub_home('models')