diff --git a/test/test_utils.py b/test/test_utils.py index e35c0326766ce395d66b5e6e9c726c4dbcf4f512..a5070de68872a2894e9bdd2186c6f302302af725 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -20,6 +20,7 @@ import sys import tarfile import zipfile import platform +import functools lasttime = time.time() FLUSH_INTERVAL = 0.1 @@ -78,8 +79,10 @@ def _uncompress_file(filepath, extrapath, delete_file, print_progress): if filepath.endswith("zip"): handler = _uncompress_file_zip - else: + elif filepath.endswith("tgz"): handler = _uncompress_file_tar + else: + handler = functools.partial(_uncompress_file_tar, mode="r") for total_num, index in handler(filepath, extrapath): if print_progress: @@ -104,8 +107,8 @@ def _uncompress_file_zip(filepath, extrapath): yield total_num, index -def _uncompress_file_tar(filepath, extrapath): - files = tarfile.open(filepath, "r:gz") +def _uncompress_file_tar(filepath, extrapath, mode="r:gz"): + files = tarfile.open(filepath, mode) filelist = files.getnames() total_num = len(filelist) for index, file in enumerate(filelist):