提交 e08a624b 编写于 作者: L LielinJiang

update download name

上级 e26905fe
......@@ -29,7 +29,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
import logging
logger = logging.getLogger(__name__)
__all__ = ['get_weights_path', 'is_url']
__all__ = ['get_weights_path_from_url', 'is_url']
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
......@@ -45,48 +45,56 @@ def is_url(path):
return path.startswith('http://') or path.startswith('https://')
def get_weights_path(url, md5sum=None):
def get_weights_path_from_url(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
Args:
url (str): download url
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded weights.
"""
path, _ = get_path(url, WEIGHTS_HOME, md5sum)
path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
return path
def map_path(url, root_dir):
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def get_path(url, root_dir, md5sum=None, check_exist=True):
def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = map_path(url, root_dir)
fullpath = _map_path(url, root_dir)
exist_flag = False
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
exist_flag = True
if ParallelEnv().local_rank == 0:
logger.info("Found {}".format(fullpath))
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
return fullpath, exist_flag
return fullpath
def _download(url, path, md5sum=None):
......@@ -109,8 +117,8 @@ def _download(url, path, md5sum=None):
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if ParallelEnv().local_rank == 0:
logger.info("Downloading {} from {}".format(fname, url))
logger.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
......@@ -141,8 +149,8 @@ def _download(url, path, md5sum=None):
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 checking...".format(fullname))
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
......@@ -150,8 +158,7 @@ def _md5check(fullname, md5sum=None):
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
if ParallelEnv().local_rank == 0:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
......@@ -20,14 +20,15 @@ from paddle.fluid.regularizer import L2Decay
from paddle.fluid.dygraph.nn import Conv2D, BatchNorm, Pool2D, Linear
from hapi.model import Model
from hapi.download import get_weights_path
from hapi.download import get_weights_path_from_url
__all__ = ['DarkNet', 'darknet53']
# {num_layers: (url, md5)}
pretrain_infos = {
53: ('https://paddle-hapi.bj.bcebos.com/models/darknet53.pdparams',
'ca506a90e2efecb9a2093f8ada808708')
model_urls = {
'darknet53':
('https://paddle-hapi.bj.bcebos.com/models/darknet53.pdparams',
'ca506a90e2efecb9a2093f8ada808708')
}
......@@ -213,16 +214,15 @@ class DarkNet(Model):
return out
def _darknet(num_layers=53, pretrained=False, **kwargs):
def _darknet(arch, num_layers=53, pretrained=False, **kwargs):
model = DarkNet(num_layers, **kwargs)
if pretrained:
assert num_layers in pretrain_infos.keys(), \
"DarkNet{} do not have pretrained weights now, " \
"pretrained should be set as False".format(num_layers)
weight_path = get_weights_path(*(pretrain_infos[num_layers]))
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(*(model_urls[arch]))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
model.load(weight_path[:-9])
model.load(weight_path)
return model
......@@ -234,4 +234,4 @@ def darknet53(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet,
default True.
"""
return _darknet(53, pretrained, **kwargs)
return _darknet('darknet53', 53, pretrained, **kwargs)
......@@ -20,7 +20,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from hapi.model import Model
from hapi.download import get_weights_path
from hapi.download import get_weights_path_from_url
__all__ = ['MobileNetV1', 'mobilenet_v1']
......@@ -267,11 +267,11 @@ def _mobilenet(arch, pretrained=False, **kwargs):
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
model.load(weight_path)
return model
......
......@@ -19,7 +19,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from hapi.model import Model
from hapi.download import get_weights_path
from hapi.download import get_weights_path_from_url
__all__ = ['MobileNetV2', 'mobilenet_v2']
......@@ -241,11 +241,11 @@ def _mobilenet(arch, pretrained=False, **kwargs):
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
model.load(weight_path)
return model
......
......@@ -23,7 +23,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from hapi.model import Model
from hapi.download import get_weights_path
from hapi.download import get_weights_path_from_url
__all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
......@@ -267,11 +267,11 @@ def _resnet(arch, Block, depth, pretrained, **kwargs):
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
model.load(weight_path)
return model
......
......@@ -18,7 +18,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from hapi.model import Model
from hapi.download import get_weights_path
from hapi.download import get_weights_path_from_url
__all__ = [
'VGG',
......@@ -128,11 +128,11 @@ def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path(model_urls[arch][0],
model_urls[arch][1])
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])
assert weight_path.endswith(
'.pdparams'), "suffix of weight must be .pdparams"
model.load(weight_path[:-9])
model.load(weight_path)
return model
......
......@@ -26,7 +26,8 @@ from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from hapi.datasets.mnist import MNIST as MnistDataset
from hapi.model import Model, CrossEntropy, Input, set_device
from hapi.model import Model, Input, set_device
from hapi.loss import CrossEntropy
from hapi.metrics import Accuracy
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册