downloader.py 5.4 KB
Newer Older
S
Steffy-zxf 已提交
1
#coding:utf-8
W
wuzewu 已提交
2
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
Z
Zeyu Chen 已提交
3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Z
Zeyu Chen 已提交
15
# coding=utf-8
Z
Zeyu Chen 已提交
16

Z
Zeyu Chen 已提交
17 18 19 20
from __future__ import print_function
from __future__ import division
from __future__ import print_function

W
wuzewu 已提交
21
import shutil
Z
Zeyu Chen 已提交
22
import os
Z
Zeyu Chen 已提交
23
import sys
W
wuzewu 已提交
24
import time
Z
Zeyu Chen 已提交
25 26 27
import hashlib
import requests
import tarfile
W
wuzewu 已提交
28

W
wuzewu 已提交
29 30
from paddlehub.common import utils
from paddlehub.common.logger import logger
Z
Zeyu Chen 已提交
31

W
wuzewu 已提交
32
# Names exported by a star-import of this module.
__all__ = ['Downloader', 'progress']
W
wuzewu 已提交
33
# Minimum number of seconds that must elapse between two progress() writes.
FLUSH_INTERVAL = 0.1
Z
Zeyu Chen 已提交
34

W
wuzewu 已提交
35
# Timestamp of the most recent progress() write; initialised at import time
# and updated (or reset to 0 to force a write) by progress().
lasttime = time.time()
Z
Zeyu Chen 已提交
36

W
wuzewu 已提交
37 38 39 40 41 42 43 44 45 46

def progress(str, end=False):
    """Write a progress line to stdout, rate-limited by FLUSH_INTERVAL.

    The text is written after a carriage return so that it overwrites the
    current terminal line. A write happens only when at least
    ``FLUSH_INTERVAL`` seconds have passed since the previous one.

    Args:
        str: The text to display.
        end: When true, a newline is appended and the rate limit is bypassed
            so that the final line is always written.
    """
    global lasttime
    if end:
        str += "\n"
        # Resetting the timestamp guarantees the interval check below passes.
        lasttime = 0
    elapsed = time.time() - lasttime
    if elapsed < FLUSH_INTERVAL:
        return
    sys.stdout.write("\r%s" % str)
    lasttime = time.time()
    sys.stdout.flush()
Z
Zeyu Chen 已提交
47 48


W
wuzewu 已提交
49
class Downloader(object):
    """Download files over HTTP and unpack gzip-compressed tar archives.

    Every public method returns a ``(success, tips, path)`` tuple: a boolean
    result, a human-readable message and the resulting path (``None`` when
    the operation failed).
    """

    def download_file(self,
                      url,
                      save_path,
                      save_name=None,
                      retry_limit=3,
                      print_progress=False,
                      replace=False):
        """Download ``url`` into ``save_path`` unless the file already exists.

        Args:
            url: Address of the file to fetch.
            save_path: Directory that receives the file; created through
                ``utils.mkdir`` when it does not exist.
            save_name: File name inside ``save_path``. Defaults to the last
                ``/``-separated component of ``url``.
            retry_limit: Maximum number of download attempts.
            print_progress: Print a progress bar while downloading (only
                possible when the server reports ``content-length``).
            replace: Remove an already existing target file first so that it
                is downloaded again.

        Returns:
            ``(True, tips, file_name)`` when the file is present afterwards,
            ``(False, tips, None)`` when every attempt failed.
        """
        if not os.path.exists(save_path):
            utils.mkdir(save_path)
        save_name = url.split('/')[-1] if save_name is None else save_name
        file_name = os.path.join(save_path, save_name)
        retry_times = 0

        if replace and os.path.exists(file_name):
            os.remove(file_name)

        # Data is streamed into a side file and only renamed once complete,
        # so an interrupted transfer can never be mistaken for a finished
        # download by the existence check below (or by a later call).
        part_name = file_name + ".part"

        while not os.path.exists(file_name):
            if retry_times < retry_limit:
                retry_times += 1
            else:
                tips = "Cannot download {0} within retry limit {1}".format(
                    url, retry_limit)
                return False, tips, None
            try:
                self._fetch(url, part_name, save_name, print_progress)
                shutil.move(part_name, file_name)
            except requests.exceptions.RequestException:
                # Connection problems and HTTP error statuses consume one
                # attempt; the loop tries again until retry_limit is reached.
                pass
            finally:
                # Never leave a partial file behind, whatever went wrong.
                if os.path.exists(part_name):
                    os.remove(part_name)

        tips = "File %s download completed!" % (file_name)
        return True, tips, file_name

    def _fetch(self, url, dest, display_name, print_progress):
        """Stream the body of ``url`` into the file ``dest``.

        Raises a ``requests`` exception on connection failure or when the
        server answers with an HTTP error status.
        """
        r = requests.get(url, stream=True)
        try:
            # Without this check an error page would be stored as the file.
            r.raise_for_status()
            total_length = r.headers.get('content-length')

            with open(dest, 'wb') as f:
                if total_length is None:
                    # Size unknown: no progress can be shown. iter_content is
                    # used here as well so both branches store the same
                    # (content-decoded) bytes.
                    for data in r.iter_content(chunk_size=4096):
                        f.write(data)
                    return

                total_length = int(total_length)
                if print_progress:
                    print("Downloading %s" % display_name)
                dl = 0
                for data in r.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    if print_progress and total_length > 0:
                        # Clamp: decoded content may exceed content-length.
                        ratio = min(float(dl) / total_length, 1.0)
                        progress("[%-50s] %.2f%%" % ('=' * int(50 * ratio),
                                                     ratio * 100))
            if print_progress:
                progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
        finally:
            # Streamed responses hold their connection until closed.
            r.close()

    @staticmethod
    def _is_within_directory(base, target):
        """Return True when ``target`` is ``base`` itself or lies below it."""
        base = os.path.abspath(base)
        target = os.path.abspath(target)
        # os.path.join(base, "") appends exactly one trailing separator,
        # which also behaves correctly when base is the filesystem root.
        return target == base or target.startswith(os.path.join(base, ""))

    def uncompress(self,
                   file,
                   dirname=None,
                   delete_file=False,
                   print_progress=False):
        """Extract the gzip-compressed tar archive ``file`` into ``dirname``.

        Args:
            file: Path of the ``.tar.gz`` archive.
            dirname: Extraction directory. Defaults to the directory that
                contains ``file``.
            delete_file: Remove the archive after a successful extraction.
            print_progress: Print a progress bar while extracting.

        Returns:
            ``(True, tips, module_dir)`` where ``module_dir`` is ``dirname``
            joined with the first member name of the archive, or
            ``(False, tips, None)`` when the archive is empty or contains a
            member whose path would be written outside ``dirname``. Nothing
            is extracted and the archive is kept in the failure case.
        """
        dirname = os.path.dirname(file) if dirname is None else dirname
        if print_progress:
            print("Uncompress %s" % file)
        with tarfile.open(file, "r:gz") as tar:
            file_names = tar.getnames()
            if not file_names:
                return False, "File %s is an empty archive!" % file, None

            # The archive usually comes from the network: refuse absolute or
            # "../" member names before anything is written to disk.
            # NOTE(review): link members are not inspected here.
            for file_name in file_names:
                target = os.path.join(dirname, file_name)
                if not self._is_within_directory(dirname, target):
                    tips = "File %s contains unsafe member path %s" % (
                        file, file_name)
                    return False, tips, None

            # At least 1 so a single-member archive cannot divide by zero.
            size = max(len(file_names) - 1, 1)
            module_dir = os.path.join(dirname, file_names[0])
            for index, file_name in enumerate(file_names):
                if print_progress:
                    ratio = float(index) / size
                    progress("[%-50s] %.2f%%" % ('=' * int(50 * ratio),
                                                 ratio * 100))
                tar.extract(file_name, dirname)

            if print_progress:
                progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
        if delete_file:
            os.remove(file)

        return True, "File %s uncompress completed!" % file, module_dir

    def download_file_and_uncompress(self,
                                     url,
                                     save_path,
                                     save_name=None,
                                     retry_limit=3,
                                     delete_file=True,
                                     print_progress=False,
                                     replace=False):
        """Download an archive with download_file() and unpack it.

        Args:
            url: Address of the ``.tar.gz`` archive.
            save_path: Directory used for the download and the extraction.
            save_name: Name of the downloaded archive inside ``save_path``;
                when given, the extracted result is afterwards moved to that
                same ``save_path/save_name`` location.
            retry_limit: Maximum number of download attempts.
            delete_file: Remove the archive after extraction.
            print_progress: Print progress bars for both steps.
            replace: Re-download even if the archive already exists.

        Returns:
            ``(True, tips, path)`` on success, where ``path`` is the moved
            location when ``save_name`` is given and the extracted
            ``module_dir`` otherwise. On failure, the failing step's
            ``(False, tips, None)`` tuple.
        """
        result, tips_1, file = self.download_file(
            url=url,
            save_path=save_path,
            save_name=save_name,
            retry_limit=retry_limit,
            print_progress=print_progress,
            replace=replace)
        if not result:
            return result, tips_1, file
        result, tips_2, file = self.uncompress(
            file, delete_file=delete_file, print_progress=print_progress)
        if not result:
            return result, tips_2, file
        if save_name:
            # NOTE(review): the archive itself was stored under this name, so
            # this presumably relies on delete_file=True having removed it.
            save_name = os.path.join(save_path, save_name)
            shutil.move(file, save_name)
            return result, "%s\n%s" % (tips_1, tips_2), save_name
        return result, "%s\n%s" % (tips_1, tips_2), file
W
wuzewu 已提交
152

W
wuzewu 已提交
153 154

# Module-level Downloader instance created at import time.
default_downloader = Downloader()