# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddlehub.env as hubenv
from paddle.utils.download import get_path_from_url
from paddlehub.utils import log, utils, xarfile


def download_data(url):
    """Build a pass-through class decorator that first fetches ``url``.

    The fetch happens immediately, when ``download_data(url)`` is evaluated,
    not when the decorated object is later used. The expected location is
    ``hubenv.DATA_HOME`` joined with the URL's basename cut at its first
    ``'.'``; ``get_path_from_url`` is only invoked when that path does not
    exist yet.

    Args:
        url: Address of the resource to fetch.

    Returns:
        A decorator that returns the object it receives unchanged.
    """
    # Everything before the first dot of the file name, e.g. "a.tar.gz" -> "a".
    stem = os.path.basename(url).split('.')[0]
    expected_location = os.path.join(hubenv.DATA_HOME, stem)

    already_present = os.path.exists(expected_location)
    if not already_present:
        get_path_from_url(url, hubenv.DATA_HOME)

    def _wrapper(Dataset):
        # Identity decorator: the download above is the only side effect.
        return Dataset

    return _wrapper


class Downloader:
    """Fetches an archive into a temporary directory and unpacks it."""

    def download_file_and_uncompress(self, url: str, save_path: str, print_progress: bool):
        """Download ``url`` to a temporary directory and unarchive it into ``save_path``.

        Args:
            url: Address of the archive to download.
            save_path: Destination passed to the unarchive helpers.
            print_progress: When true, the ``*_with_progress`` helpers are
                used and a ``log.ProgressBar`` is updated for each item they
                yield; otherwise the plain one-shot helpers are called.
        """
        with utils.generate_tempdir() as tmp_dir:
            if not print_progress:
                # Quiet path: one call to download, one call to unpack.
                archive = utils.download(url=url, path=tmp_dir)
                xarfile.unarchive(name=archive, path=save_path)
                return

            # NOTE(review): each yielded triple is presumably
            # (path, amount done, total amount) - confirm against the helpers.
            with log.ProgressBar('Download {}'.format(url)) as progress:
                for archive, done, total in utils.download_with_progress(url=url, path=tmp_dir):
                    progress.update(float(done) / total)

            # ``archive`` holds the path from the last download item yielded.
            with log.ProgressBar('Decompress {}'.format(archive)) as progress:
                for _, done, total in xarfile.unarchive_with_progress(name=archive, path=save_path):
                    progress.update(float(done) / total)


# Shared module-level instance.
default_downloader = Downloader()