提交 315b908c 编写于 作者: M Megvii Engine Team 提交者: XindaH

feat(imperative): mge hub pretrained support s3

GitOrigin-RevId: a48e107623e9992ba15c84a55959dbb414df2117
上级 c96dbd29
...@@ -106,9 +106,7 @@ class CIFAR10(VisionDataset): ...@@ -106,9 +106,7 @@ class CIFAR10(VisionDataset):
def download(self): def download(self):
url = self.url_path + self.raw_file_name url = self.url_path + self.raw_file_name
load_raw_data_from_url( load_raw_data_from_url(url, self.raw_file_name, self.raw_file_md5, self.root)
url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
)
self.process() self.process()
def untar(self, file_path, dirs): def untar(self, file_path, dirs):
......
...@@ -118,7 +118,7 @@ class MNIST(VisionDataset): ...@@ -118,7 +118,7 @@ class MNIST(VisionDataset):
def download(self): def download(self):
for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5): for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
url = self.url_path + file_name url = self.url_path + file_name
load_raw_data_from_url(url, file_name, md5, self.root, self.timeout) load_raw_data_from_url(url, file_name, md5, self.root)
def process(self, train): def process(self, train):
# load raw files and transform them into meta data and datasets Tuple(np.array) # load raw files and transform them into meta data and datasets Tuple(np.array)
......
...@@ -27,9 +27,7 @@ def _default_dataset_root(): ...@@ -27,9 +27,7 @@ def _default_dataset_root():
return default_dataset_root return default_dataset_root
def load_raw_data_from_url( def load_raw_data_from_url(url: str, filename: str, target_md5: str, raw_data_dir: str):
url: str, filename: str, target_md5: str, raw_data_dir: str, timeout: int
):
cached_file = os.path.join(raw_data_dir, filename) cached_file = os.path.join(raw_data_dir, filename)
logger.debug( logger.debug(
"load_raw_data_from_url: downloading to or using cached %s ...", cached_file "load_raw_data_from_url: downloading to or using cached %s ...", cached_file
...@@ -41,7 +39,7 @@ def load_raw_data_from_url( ...@@ -41,7 +39,7 @@ def load_raw_data_from_url(
" File may be downloaded multiple times. We recommend\n" " File may be downloaded multiple times. We recommend\n"
" users to download in single process first." " users to download in single process first."
) )
md5 = download_from_url(url, cached_file, http_read_timeout=timeout) md5 = download_from_url(url, cached_file)
else: else:
md5 = calculate_md5(cached_file) md5 = calculate_md5(cached_file)
if target_md5 == md5: if target_md5 == md5:
......
...@@ -25,7 +25,6 @@ from .const import ( ...@@ -25,7 +25,6 @@ from .const import (
DEFAULT_PROTOCOL, DEFAULT_PROTOCOL,
ENV_MGE_HOME, ENV_MGE_HOME,
ENV_XDG_CACHE_HOME, ENV_XDG_CACHE_HOME,
HTTP_READ_TIMEOUT,
HUBCONF, HUBCONF,
HUBDEPENDENCY, HUBDEPENDENCY,
) )
...@@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any: ...@@ -263,14 +262,14 @@ def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
" File may be downloaded multiple times. We recommend\n" " File may be downloaded multiple times. We recommend\n"
" users to download in single process first." " users to download in single process first."
) )
download_from_url(url, cached_file, HTTP_READ_TIMEOUT) download_from_url(url, cached_file)
state_dict = _mge_load_serialized(cached_file) state_dict = _mge_load_serialized(cached_file)
return state_dict return state_dict
class pretrained: class pretrained:
r"""Decorator which helps to download pretrained weights from the given url. r"""Decorator which helps to download pretrained weights from the given url. Supported sources include fs, s3, and http(s).
For example, we can decorate a resnet18 function as follows For example, we can decorate a resnet18 function as follows
......
...@@ -12,6 +12,7 @@ import shutil ...@@ -12,6 +12,7 @@ import shutil
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import requests import requests
from megfile import smart_copy, smart_getmd5, smart_getsize
from tqdm import tqdm from tqdm import tqdm
from ..logger import get_logger from ..logger import get_logger
...@@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException): ...@@ -26,41 +27,21 @@ class HTTPDownloadError(BaseException):
r"""The class that represents http request error.""" r"""The class that represents http request error."""
def download_from_url(url: str, dst: str, http_read_timeout=120): class Bar:
def __init__(self, total=100):
self._bar = tqdm(total=total, unit="iB", unit_scale=True, ncols=80)
def __call__(self, bytes_num):
self._bar.update(bytes_num)
def download_from_url(url: str, dst: str):
r"""Downloads file from given url to ``dst``. r"""Downloads file from given url to ``dst``.
Args: Args:
url: source URL. url: source URL.
dst: saving path. dst: saving path.
http_read_timeout: how many seconds to wait for data before giving up.
""" """
dst = os.path.expanduser(dst) dst = os.path.expanduser(dst)
dst_dir = os.path.dirname(dst) smart_copy(url, dst, callback=Bar(total=smart_getsize(url)))
return smart_getmd5(dst)
resp = requests.get(
url, timeout=(HTTP_CONNECTION_TIMEOUT, http_read_timeout), stream=True
)
if resp.status_code != 200:
raise HTTPDownloadError("An error occured when downloading from {}".format(url))
md5 = hashlib.md5()
total_size = int(resp.headers.get("Content-Length", 0))
bar = tqdm(
total=total_size, unit="iB", unit_scale=True, ncols=80
) # pylint: disable=blacklisted-name
try:
with NamedTemporaryFile("w+b", delete=False, suffix=".tmp", dir=dst_dir) as f:
logger.info("Download file to temp file %s", f.name)
for chunk in resp.iter_content(CHUNK_SIZE):
if not chunk:
break
bar.update(len(chunk))
f.write(chunk)
md5.update(chunk)
bar.close()
shutil.move(f.name, dst)
finally:
# ensure tmp file is removed
if os.path.exists(f.name):
os.remove(f.name)
return md5.hexdigest()
...@@ -8,3 +8,4 @@ redispy ...@@ -8,3 +8,4 @@ redispy
deprecated deprecated
mprop mprop
wheel wheel
megfile>=0.0.10
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册