提交 315b908c 编写于 作者: M Megvii Engine Team 提交者: XindaH

feat(imperative): mge hub pretrained support s3

GitOrigin-RevId: a48e107623e9992ba15c84a55959dbb414df2117
上级 c96dbd29
......@@ -106,9 +106,7 @@ class CIFAR10(VisionDataset):
def download(self):
url = self.url_path + self.raw_file_name
load_raw_data_from_url(
url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
)
load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root)
self.process()
def untar(self, file_path, dirs):
......
......@@ -118,7 +118,7 @@ class MNIST(VisionDataset):
def download(self):
for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
url = self.url_path + file_name
load_raw_data_from_url(url, file_name, md5, self.root, self.timeout)
load_raw_data_from_url(url, file_name, md5, self.root)
def process(self, train):
# load raw files and transform them into meta data and datasets Tuple(np.array)
......
......@@ -27,9 +27,7 @@ def _default_dataset_root():
return default_dataset_root
def load_raw_data_from_url(
url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int
):
def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str):
cached_file = os.path.join(raw_data_dir, filename)
logger.debug(
"load_raw_data_from_url: downloading to or using cached %s ...", cached_file
......@@ -41,7 +39,7 @@ def load_raw_data_from_url(
" File may be downloaded multiple times. We recommend\n"
" users to download in single process first."
)
md5 = download_from_url(url, cached_file, http_read_timeout=timeout)
md5 = download_from_url(url, cached_file)
else:
md5 = calculate_md5(cached_file)
if target_md5 == md5:
......
......@@ -25,7 +25,6 @@ from .const import (
DEFAULT_PROTOCOL,
ENV_MGE_HOME,
ENV_XDG_CACHE_HOME,
HTTP_READ_TIMEOUT,
HUBCONF,
HUBDEPENDENCY,
)
......@@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
" File may be downloaded multiple times. We recommend\n"
" users to download in single process first."
)
download_from_url(url, cached_file, HTTP_READ_TIMEOUT)
download_from_url(url, cached_file)
state_dict = _mge_load_serialized(cached_file)
return state_dict
class pretrained:
r"""Decorator which helps to download pretrained weights from the given url.
r"""Decorator which helps to download pretrained weights from the given url. Including fs, s3, http(s).
For example, we can decorate a resnet18 function as follows
......
......@@ -12,6 +12,7 @@ import shutil
from tempfile import NamedTemporaryFile
import requests
from megfile import smart_copy, smart_getmd5, smart_getsize
from tqdm import tqdm
from ..logger import get_logger
......@@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException):
r"""The class that represents http request error."""
def download_from_url(url: str, dst: str, http_read_timeout=120):
class Bar:
    r"""Byte-count progress callback handed to ``smart_copy``.

    Wraps a ``tqdm`` progress bar; each invocation advances the bar by the
    number of bytes transferred since the previous call.
    """

    def __init__(self, total=100):
        # 80-column bar with auto-scaled binary byte units ("iB").
        bar_options = dict(total=total, unit="iB", unit_scale=True, ncols=80)
        self._bar = tqdm(**bar_options)

    def __call__(self, bytes_num):
        # Advance by the size of the chunk just copied.
        self._bar.update(bytes_num)
def download_from_url(url: str, dst: str):
r"""Downloads file from given url to ``dst``.
Args:
url: source URL.
dst: saving path.
http_read_timeout: how many seconds to wait for data before giving up.
"""
dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst)
resp = requests.get(
url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True
)
if resp.status_code != 200:
raise HTTPDownloadError("An error occured when downloading from {}".format(url))
md5 = hashlib.md5()
total_size = int(resp.headers.get("Content-Length", 0))
bar = tqdm(
total=total_size, unit="iB", unit_scale=True, ncols=80
) # pylint: disable=blacklisted-name
try:
with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f:
logger.info("Download file to temp file %s", f.name)
for chunk in resp.iter_content(CHUNK_SIZE):
if not chunk:
break
bar.update(len(chunk))
f.write(chunk)
md5.update(chunk)
bar.close()
shutil.move(f.name, dst)
finally:
# ensure tmp file is removed
if os.path.exists(f.name):
os.remove(f.name)
return md5.hexdigest()
smart_copy(url, dst, callback=Bar(total=smart_getsize(url)))
return smart_getmd5(dst)
......@@ -8,3 +8,4 @@ redispy
deprecated
mprop
wheel
megfile>=0.0.10
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册