xarfile.py 9.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tarfile
import zipfile
from typing import Callable
from typing import Generator
from typing import List
from typing import Tuple

import rarfile


class XarInfo(object):
    '''
    Uniform view over one archive member record.

    Wraps the member-info object handed out by the underlying archive library and exposes its name and size
    under a single pair of property names, regardless of which attribute names that library uses.

    Args:
        _xarinfo : the wrapped member-info object
        arctype(str) : archive type the record came from; 'tar' selects the `name`/`size` attributes, any other
                       value selects `filename`/`file_size`
    '''

    def __init__(self, _xarinfo, arctype='tar'):
        self.arctype = arctype
        self._info = _xarinfo

    @property
    def name(self) -> str:
        '''Member name as stored in the archive.'''
        attribute = 'name' if self.arctype == 'tar' else 'filename'
        return getattr(self._info, attribute)

    @property
    def size(self) -> int:
        '''Member size as reported by the wrapped record.'''
        attribute = 'size' if self.arctype == 'tar' else 'file_size'
        return getattr(self._info, attribute)


class XarFile(object):
    '''
    The XarFile Class provides an interface to tar/rar/zip archives.

    Args:
        name(str) : path of the archive file to open
        mode(str) : specifies the mode in which the file is opened, it must be:
            =========  ==============================================================================================
            Character  Meaning
            ---------  ----------------------------------------------------------------------------------------------
            'r'        open for reading
            'w'        open for writing, truncating the file first, file will be saved according to the arctype field
            'a'        open for writing, appending to the end of the file if it exists
            =========  ==============================================================================================
        arctype(str) : archive type, support ['tar' 'rar' 'zip' 'tar.gz' 'tar.bz2' 'tar.xz' 'tgz' 'txz'], if
                       the mode is 'w' or 'a', the default is 'tar', if the mode is 'r', it will be based on actual
                       archive type of file
        **kwargs : forwarded unchanged to tarfile.open / zipfile.ZipFile / rarfile.RarFile

    Raises:
        RuntimeError: if `mode` is not one of 'r', 'w', 'a', if the archive type is not supported, or if a file
                      opened with mode 'r' is not recognised as a tar, zip or rar archive.

    After construction `self.arctype` is always one of 'tar', 'zip' or 'rar'; every tar variant is reported as
    'tar' because all of them are served by the same tarfile handle.
    '''

    # Every accepted spelling of a tar-based archive type.
    _TAR_ARCTYPES = ('tar', 'tar.gz', 'tgz', 'tar.bz2', 'tar.xz', 'txz')
    # tarfile write mode for each compressed tar variant; plain 'tar' keeps mode 'w'.
    _TAR_WRITE_MODES = {
        'tar.gz': 'w:gz',
        'tgz': 'w:gz',
        'tar.bz2': 'w:bz2',
        'tar.xz': 'w:xz',
        'txz': 'w:xz',
    }

    def __init__(self, name: str, mode: str, arctype: str = 'tar', **kwargs):
        if mode == 'w':
            # Pick the compression-specific tarfile mode, then collapse the variant to plain 'tar'.
            mode = self._TAR_WRITE_MODES.get(arctype, mode)
            self.arctype = 'tar' if arctype in self._TAR_ARCTYPES else arctype
        elif mode == 'r':
            # The requested arctype is ignored when reading: the file content decides.
            if tarfile.is_tarfile(name):
                self.arctype = 'tar'
                mode = 'r:*'
            elif zipfile.is_zipfile(name):
                self.arctype = 'zip'
            elif rarfile.is_rarfile(name):
                self.arctype = 'rar'
            else:
                # Without this branch self.arctype would stay unset and fail later with AttributeError.
                raise RuntimeError('Unsupported archive file {}'.format(name))
        elif mode == 'a':
            # Normalise tar variants so the method dispatch below treats the handle as a TarFile.
            # The mode stays 'a', which tarfile documents as appending without compression.
            self.arctype = 'tar' if arctype in self._TAR_ARCTYPES else arctype
        else:
            raise RuntimeError('Unsupported mode {}'.format(mode))

        if self.arctype == 'tar':
            self._archive_fp = tarfile.open(name, mode, **kwargs)
        elif self.arctype == 'zip':
            self._archive_fp = zipfile.ZipFile(name, mode, **kwargs)
        elif self.arctype == 'rar':
            self._archive_fp = rarfile.RarFile(name, mode, **kwargs)
        else:
            raise RuntimeError('Unsupported archive type {}'.format(
                self.arctype))

    def _close(self):
        '''Close the underlying archive handle if one was ever opened.'''
        archive_fp = getattr(self, '_archive_fp', None)
        if archive_fp is not None:
            archive_fp.close()

    def __del__(self):
        # __init__ may have raised before _archive_fp was assigned, so close defensively.
        self._close()

    def __enter__(self):
        return self

    def __exit__(self, exit_exception, exit_value, exit_traceback):
        # Always release the handle, then let any in-flight exception propagate untouched.
        self._close()
        return False

    def add(self,
            name: str,
            arcname: str = None,
            recursive: bool = True,
            exclude: Callable = None):
        '''
        Add the file `name' to the archive. `name' may be any type of file (directory, fifo, symbolic link, etc.).
        If given, `arcname' specifies an alternative name for the file in the archive. Directories are added
        recursively by default. This can be avoided by setting `recursive' to False. `exclude' is a function that
        should return True for each filename to be excluded.

        For tar archives `exclude' receives the member name as stored in the archive; for other archive types it
        receives the filesystem path of the entry. An excluded directory is skipped together with its contents.
        The top-level `name' of a non-tar archive is always written and is never passed to `exclude'.

        NOTE(review): non-tar archives are written through `_archive_fp.write`; whether rarfile offers that method
        is not visible from this file - confirm before relying on 'rar' for writing.
        '''
        if self.arctype == 'tar':
            member_filter = None
            if exclude is not None:
                # tarfile's filter must return the TarInfo to keep a member or None to drop it.
                def member_filter(tarinfo):
                    return None if exclude(tarinfo.name) else tarinfo

            self._archive_fp.add(name, arcname, recursive, filter=member_filter)
            return

        self._archive_fp.write(name, arcname)
        if not recursive or not os.path.isdir(name):
            return

        for dirpath, dirnames, filenames in os.walk(name):
            if exclude:
                # Prune in place so os.walk does not descend into excluded directories.
                dirnames[:] = [d for d in dirnames if not exclude(os.path.join(dirpath, d))]
                filenames = [f for f in filenames if not exclude(os.path.join(dirpath, f))]

            for entry in filenames + dirnames:
                item = os.path.join(dirpath, entry)
                if arcname is None:
                    self._archive_fp.write(item)
                else:
                    # Re-root the entry under arcname, mirroring its position below `name'.
                    self._archive_fp.write(item, os.path.join(arcname, os.path.relpath(item, name)))

    def extract(self, name: str, path: str):
        '''
        Extract a file from the archive to the specified path.

        NOTE(review): `name' and `path' are handed to the underlying library as-is; no extra validation of member
        paths is done here, so take care with archives from untrusted sources.
        '''
        return self._archive_fp.extract(name, path)

    def extractall(self, path: str):
        '''
        Extract all files from the archive to the specified path.

        NOTE(review): member paths are not validated here before extraction; take care with archives from
        untrusted sources.
        '''
        return self._archive_fp.extractall(path)

    def getnames(self) -> List[str]:
        '''Return a list of file names in the archive.'''
        if self.arctype == 'tar':
            return self._archive_fp.getnames()
        return self._archive_fp.namelist()

    def getxarinfo(self, name: str) -> 'XarInfo':
        '''Return the instance of XarInfo given 'name'.'''
        if self.arctype == 'tar':
            return XarInfo(self._archive_fp.getmember(name), self.arctype)
        return XarInfo(self._archive_fp.getinfo(name), self.arctype)


def open(name: str, mode: str = 'w', **kwargs) -> XarFile:
    '''
    Build a XarFile for the archive at `name`.

    Args:
        name(str) : path of the archive file
        mode(str) : open mode passed to XarFile, 'w' by default
        **kwargs : any further keyword arguments, passed to XarFile unchanged

    Returns:
        XarFile: the newly constructed archive object
    '''
    xar_file = XarFile(name, mode, **kwargs)
    return xar_file


def archive(filename: str,
            recursive: bool = True,
            exclude: Callable = None,
            arctype: str = 'tar') -> str:
    '''
    Archive a file or directory

    The archive is named after the last component of `filename` with its extension stripped, followed by
    `.<arctype>`. The name carries no directory part, so it is created relative to the current working directory.

    Args:
        filename(str) : file or directory path to be archived
        recursive(bool) : whether to recursively archive directories
        exclude(Callable) : function that should return True for each filename to be excluded
        arctype(str) : archive type, support ['tar' 'rar' 'zip' 'tar.gz' 'tar.bz2' 'tar.xz' 'tgz' 'txz']

    Returns:
        str: archived file path

    Examples:
        .. code-block:: python

            archive_path = '/PATH/TO/FILE'
            saved_path = archive(archive_path, arctype='tar.gz')
    '''
    # normpath drops a trailing separator; without it basename('dir/') is '' and the archive is named '.tar'.
    basename = os.path.splitext(os.path.basename(os.path.normpath(filename)))[0]
    savename = '{}.{}'.format(basename, arctype)
    with open(savename, mode='w', arctype=arctype) as file:
        file.add(filename, recursive=recursive, exclude=exclude)

    return savename


def unarchive(name: str, path: str):
    '''
    Extract every member of an archive

    Args:
        name(str) : path of the archive file to be unarchived
        path(str) : directory the archive contents are extracted into

    Examples:
        .. code-block:: python

            unarchive_path = '/PATH/TO/FILE'
            unarchive(unarchive_path, path='./output')
    '''
    with open(name, mode='r') as archive_fp:
        archive_fp.extractall(path)


def unarchive_with_progress(name: str, path: str) -> Generator[Tuple[str, int, int], None, None]:
    '''
    Unarchive a file member by member, yielding the progress after each one

    Args:
        name(str) : path of the archive file to be unarchived
        path(str) : directory the archive contents are extracted into

    Yields:
        Tuple[str, int, int]: (filename, extract_size, total_size), where `filename` is the member just extracted,
        `extract_size` is the summed size of the members extracted so far and `total_size` is the summed size of
        all members.

    Examples:
        .. code-block:: python

            unarchive_path = 'test.tar.gz'
            for filename, extract_size, total_size in unarchive_with_progress(unarchive_path, path='./output'):
                print(filename, extract_size, total_size)
    '''
    with open(name, mode='r') as file:
        # Look every member up once and reuse its size for both the total and the running count.
        member_sizes = [(filename, file.getxarinfo(filename).size) for filename in file.getnames()]
        total_size = sum(size for _, size in member_sizes)

        extract_size = 0
        for filename, size in member_sizes:
            file.extract(filename, path)
            extract_size += size
            yield filename, extract_size, total_size


def is_xarfile(file: str) -> bool:
    '''Tell whether `file` is recognised as a zip, tar or rar archive, checked in that order.'''
    detectors = (zipfile.is_zipfile, tarfile.is_tarfile, rarfile.is_rarfile)
    return any(detect(file) for detect in detectors)