download.py 1.9 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# coding:utf-8
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

H
haoyuying 已提交
16 17
import os

W
wuzewu 已提交
18
import paddlehub.env as hubenv
H
haoyuying 已提交
19
from paddle.utils.download import get_path_from_url
W
wuzewu 已提交
20
from paddlehub.utils import log, utils, xarfile
H
haoyuying 已提交
21 22 23 24


def download_data(url):
    """Build a pass-through decorator, fetching ``url`` first if needed.

    The local path that is checked is ``hubenv.DATA_HOME`` joined with the
    URL's base name cut at its first dot. If nothing exists at that path,
    ``get_path_from_url`` is invoked with ``hubenv.DATA_HOME`` as its second
    argument. Both the check and the fetch run when ``download_data(url)``
    itself is called, not when the returned decorator is applied.

    Args:
        url: Location of the resource to fetch.

    Returns:
        A one-argument callable that returns its argument unchanged.
    """
    # Everything before the first '.' in the final URL component.
    stem = os.path.basename(url).partition('.')[0]
    data_home = hubenv.DATA_HOME
    expected_path = os.path.join(data_home, stem)

    already_present = os.path.exists(expected_path)
    if not already_present:
        get_path_from_url(url, data_home)

    def _passthrough(dataset_cls):
        # The decorated object is handed back untouched.
        return dataset_cls

    return _passthrough
W
wuzewu 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54


class Downloader:
    """Downloads a file from a URL and unarchives it into a directory."""

    def download_file_and_uncompress(self, url: str, save_path: str, print_progress: bool):
        """Download ``url`` and unarchive the result into ``save_path``.

        The download is written into the directory yielded by
        ``utils.generate_tempdir()``; the downloaded file is then passed to the
        ``xarfile`` unarchive helpers with ``save_path`` as the target.

        Args:
            url: Location of the file to download.
            save_path: Destination handed to the unarchive helpers.
            print_progress: If true, both phases are driven through their
                ``*_with_progress`` generators and reported on a
                ``log.ProgressBar``; otherwise the plain helpers are called.
        """
        with utils.generate_tempdir() as workdir:
            # Quiet mode: one call per phase, no progress reporting.
            if not print_progress:
                archive = utils.download(url=url, path=workdir)
                xarfile.unarchive(name=archive, path=save_path)
                return

            # Each yielded tuple is (path, done, total); the bar is fed done / total.
            with log.ProgressBar('Download {}'.format(url)) as fetch_bar:
                for archive, done, total in utils.download_with_progress(url=url, path=workdir):
                    fetch_bar.update(float(done) / total)

            with log.ProgressBar('Decompress {}'.format(archive)) as unpack_bar:
                for _, done, total in xarfile.unarchive_with_progress(name=archive, path=save_path):
                    unpack_bar.update(float(done) / total)


# Module-level Downloader instance, created when this module is imported.
default_downloader = Downloader()