diff --git a/imperative/python/megengine/data/dataset/vision/cifar.py b/imperative/python/megengine/data/dataset/vision/cifar.py
index 16e68a226838a09d7f41d4ca33f81d1c411ca099..500b47b557496c8fce689f22a583070110d7cb3a 100644
--- a/imperative/python/megengine/data/dataset/vision/cifar.py
+++ b/imperative/python/megengine/data/dataset/vision/cifar.py
@@ -106,9 +106,7 @@ class CIFAR10(VisionDataset):
 
     def download(self):
         url = self.url_path + self.raw_file_name
-        load_raw_data_from_url(
-            url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
-        )
+        load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root)
         self.process()
 
     def untar(self, file_path, dirs):
diff --git a/imperative/python/megengine/data/dataset/vision/mnist.py b/imperative/python/megengine/data/dataset/vision/mnist.py
index efa81628bd0c25c5d6f5a9e9e8fd6874a2073cd5..ae0d9435a244121254858d3bff84616aa2cf0aaa 100644
--- a/imperative/python/megengine/data/dataset/vision/mnist.py
+++ b/imperative/python/megengine/data/dataset/vision/mnist.py
@@ -118,7 +118,7 @@ class MNIST(VisionDataset):
     def download(self):
         for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
             url = self.url_path + file_name
-            load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
+            load_raw_data_from_url(url, file_name, md5, self.root)
 
     def process(self, train):
         # load raw files and transform them into meta data and datasets Tuple(np.array)
diff --git a/imperative/python/megengine/data/dataset/vision/utils.py b/imperative/python/megengine/data/dataset/vision/utils.py
index ed077842145d7bd341af0efcaa6d120939294f72..ca878ce99ac6c313c95cb8f0e54ebdae8a0af277 100644
--- a/imperative/python/megengine/data/dataset/vision/utils.py
+++ b/imperative/python/megengine/data/dataset/vision/utils.py
@@ -27,9 +27,7 @@ def _default_dataset_root():
     return default_dataset_root
 
 
-def load_raw_data_from_url(
-    url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int
-):
+def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str):
     cached_file = os.path.join(raw_data_dir, filename)
     logger.debug(
         "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
@@ -41,7 +39,7 @@ def load_raw_data_from_url(
             " File may be downloaded multiple times. We recommend\n"
             " users to download in single process first."
         )
-        md5 = download_from_url(url, cached_file, http_read_timeout=timeout)
+        md5 = download_from_url(url, cached_file)
     else:
         md5 = calculate_md5(cached_file)
     if target_md5 == md5:
diff --git a/imperative/python/megengine/hub/hub.py b/imperative/python/megengine/hub/hub.py
index 953714bbc0dd32a5df8d7eb3f004650026052850..bad0dec8fee481429d4188e18e08606b49ec2c6e 100644
--- a/imperative/python/megengine/hub/hub.py
+++ b/imperative/python/megengine/hub/hub.py
@@ -25,7 +25,6 @@ from .const import (
     DEFAULT_PROTOCOL,
     ENV_MGE_HOME,
     ENV_XDG_CACHE_HOME,
-    HTTP_READ_TIMEOUT,
     HUBCONF,
     HUBDEPENDENCY,
 )
@@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
             " File may be downloaded multiple times. We recommend\n"
             " users to download in single process first."
         )
-        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
+        download_from_url(url, cached_file)
 
     state_dict = _mge_load_serialized(cached_file)
     return state_dict
 
 
 class pretrained:
-    r"""Decorator which helps to download pretrained weights from the given url.
+    r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s).
 
     For example, we can decorate a resnet18 function as follows
 
diff --git a/imperative/python/megengine/utils/http_download.py b/imperative/python/megengine/utils/http_download.py
index 6342be4873b38fe434e495862c7a8ac3e99358c9..4774253523ccc9e5b1fc9113e9531ac9d3196a9f 100644
--- a/imperative/python/megengine/utils/http_download.py
+++ b/imperative/python/megengine/utils/http_download.py
@@ -12,6 +12,7 @@ import shutil
 from tempfile import NamedTemporaryFile
 
 import requests
+from megfile import smart_copy, smart_getmd5, smart_getsize
 from tqdm import tqdm
 
 from ..logger import get_logger
@@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException):
     r"""The class that represents http request error."""
 
 
-def download_from_url(url: str, dst: str, http_read_timeout=120):
+class Bar:
+    def __init__(self, total=100):
+        self._bar = tqdm(total=total, unit="iB", unit_scale=True, ncols=80)
+
+    def __call__(self, bytes_num):
+        self._bar.update(bytes_num)
+
+
+def download_from_url(url: str, dst: str):
     r"""Downloads file from given url to ``dst``.
 
     Args:
         url: source URL.
         dst: saving path.
-        http_read_timeout: how many seconds to wait for data before giving up.
     """
     dst = os.path.expanduser(dst)
-    dst_dir = os.path.dirname(dst)
-
-    resp = requests.get(
-        url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True
-    )
-    if resp.status_code != 200:
-        raise HTTPDownloadError("An error occured when downloading from {}".format(url))
-
-    md5 = hashlib.md5()
-    total_size = int(resp.headers.get("Content-Length", 0))
-    bar = tqdm(
-        total=total_size, unit="iB", unit_scale=True, ncols=80
-    )  # pylint: disable=blacklisted-name
-    try:
-        with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f:
-            logger.info("Download file to temp file %s", f.name)
-            for chunk in resp.iter_content(CHUNK_SIZE):
-                if not chunk:
-                    break
-                bar.update(len(chunk))
-                f.write(chunk)
-                md5.update(chunk)
-        bar.close()
-        shutil.move(f.name, dst)
-    finally:
-        # ensure tmp file is removed
-        if os.path.exists(f.name):
-            os.remove(f.name)
-    return md5.hexdigest()
+    smart_copy(url, dst, callback=Bar(total=smart_getsize(url)))
+    return smart_getmd5(dst)
diff --git a/imperative/python/requires.txt b/imperative/python/requires.txt
index 670193dcafa09caab35e2d284d774b16674a07f2..58a806c05713288a481e8482760c4916b4ad8cca 100644
--- a/imperative/python/requires.txt
+++ b/imperative/python/requires.txt
@@ -8,3 +8,4 @@ redispy
 deprecated
 mprop
 wheel
+megfile>=0.0.10
\ No newline at end of file
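For reviewers who want to exercise the new code path locally, here is a minimal usage sketch. It assumes `megfile>=0.0.10` is installed (per the `requires.txt` change); the URL and destination path below are placeholders, not part of the patch. Since megfile dispatches on the path scheme, the same call now covers fs, s3, and http(s) sources, which is what the updated `pretrained` docstring advertises.

```python
# Sketch only: exercises the refactored helper end to end.
# The URL and destination path are hypothetical.
from megengine.utils.http_download import download_from_url

# Internally this now runs
#   smart_copy(url, dst, callback=Bar(total=smart_getsize(url)))
# and returns smart_getmd5(dst), so a tqdm progress bar is still shown.
md5 = download_from_url(
    "https://example.com/models/resnet18.pkl",  # any fs / s3:// / http(s):// source
    "/tmp/resnet18.pkl",
)
print(md5)  # hex MD5 of the downloaded file
```

One behavioral note on the design: the old implementation accumulated the MD5 digest chunk by chunk during the download, while the new one calls `smart_getmd5(dst)` after the copy completes, so the destination file is read a second time to produce the checksum.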