未验证 提交 4e460d7b 编写于 作者: W Wenyu 提交者: GitHub

Add hub Module for easy to use pre-trained models. (#31873)

* add Hub module for easy use of pre-trained models.
*   support list, load, help functions.
*   support loading models from github, gitee, or a local path
Co-authored-by: NLielinJiang <jianglielin@baidu.com>
上级 92dc9b2b
......@@ -300,6 +300,8 @@ from .hapi import Model
from .hapi import callbacks
from .hapi import summary
from .hapi import flops
from .hapi import hub
import paddle.text
import paddle.vision
......
......@@ -15,6 +15,7 @@
from . import logger
from . import callbacks
from . import model_summary
from . import hub
from . import model
from .model import *
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import sys
import shutil
import zipfile
from paddle.utils.download import get_path_from_url
# Generic per-user cache location.
# NOTE(review): not referenced by the code shown here -- confirm it is still needed.
DEFAULT_CACHE_DIR = '~/.cache'
# Name of the optional module-level attribute in hubconf.py listing the
# packages that must be importable before an entrypoint is called.
VAR_DEPENDENCY = 'dependencies'
# File expected in the repo root; its stem is the module name that gets imported.
MODULE_HUBCONF = 'hubconf.py'
# Directory under which downloaded repos are unpacked and cached.
HUB_DIR = os.path.expanduser(os.path.join('~', '.cache', 'paddle', 'hub'))
def _remove_if_exists(path):
    """Delete ``path`` if it is present; do nothing when it does not exist.

    Regular files and symbolic links (including dangling ones) are unlinked,
    real directories are removed recursively.

    Args:
        path (str): file, link or directory to remove.
    """
    # lexists() also reports dangling symlinks, which exists() would miss.
    if not os.path.lexists(path):
        return
    # shutil.rmtree() refuses to operate on a symlink, so a link is unlinked
    # like a file and its target is left untouched.
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
    else:
        shutil.rmtree(path)
def _import_module(name, repo_dir):
    """Import the top-level module ``name`` from ``repo_dir`` and return it.

    ``repo_dir`` is placed at the front of ``sys.path`` only for the duration
    of the import, and the module is not left registered in ``sys.modules``,
    so a later call for a different repo imports that repo's own file.

    Args:
        name (str): module name without the ``.py`` suffix.
        repo_dir (str): directory containing ``<name>.py``.

    Returns:
        The imported module object.

    Raises:
        RuntimeError: if importing raises ``ImportError``.
    """
    # Drop any module already registered under this name; otherwise
    # __import__ would hand back that cached module instead of the repo's.
    sys.modules.pop(name, None)
    sys.path.insert(0, repo_dir)
    try:
        hub_module = __import__(name)
    except ImportError:
        raise RuntimeError(
            'Cannot import `{}`, please make sure `{}`.py in repo root dir'.
            format(name, name))
    finally:
        # Restore sys.path even when the module body raises something other
        # than ImportError (e.g. a SyntaxError or ValueError).
        sys.path.remove(repo_dir)
    sys.modules.pop(name, None)
    return hub_module
def _git_archive_link(repo_owner, repo_name, branch, source):
    """Build the zip-archive download URL for a branch/tag of a hosted repo.

    Args:
        repo_owner (str): account that owns the repo.
        repo_name (str): repository name.
        branch (str): branch or tag to download.
        source (str): ``'github'`` or ``'gitee'``.

    Returns:
        The archive URL, or ``None`` when ``source`` is neither host.
    """
    url_templates = (
        ('github', 'https://github.com/{}/{}/archive/{}.zip'),
        ('gitee', 'https://gitee.com/{}/{}/repository/archive/{}.zip'),
    )
    for host, template in url_templates:
        if source == host:
            return template.format(repo_owner, repo_name, branch)
    return None
def _parse_repo_info(repo, source):
    """Split ``"repo_owner/repo_name[:tag_name]"`` into its three parts.

    Args:
        repo (str): repo spec, optionally suffixed with ``:branch_or_tag``.
        source (str): hosting service; when no branch is given the default is
            ``main`` for ``'github'`` and ``master`` for anything else.

    Returns:
        tuple: ``(repo_owner, repo_name, branch)``.

    Raises:
        ValueError: if ``repo`` is not of the expected form (missing or extra
            ``/`` or ``:`` separators, or an empty owner, name or branch).
    """
    repo_info, colon, branch = repo.partition(':')
    if not colon:
        branch = 'main' if source == 'github' else 'master'

    owner_and_name = repo_info.split('/')
    well_formed = (len(owner_and_name) == 2 and all(owner_and_name) and
                   branch and ':' not in branch)
    if not well_formed:
        # Fail here with a readable message rather than with an unpacking
        # error or, later, with a download of a malformed URL.
        raise ValueError(
            'Invalid repo "{}": expected format '
            '"repo_owner/repo_name[:tag_name]".'.format(repo))

    repo_owner, repo_name = owner_and_name
    return repo_owner, repo_name, branch
def _make_dirs(dirname):
    """Create directory ``dirname`` together with any missing parents.

    Does nothing if the directory already exists.

    Args:
        dirname (str): directory path to create.
    """
    try:
        from pathlib import Path
    except ImportError:
        # Fall back to the backport when the stdlib module is unavailable.
        from pathlib2 import Path
    # parents=True: the hub cache sits several levels below the home
    # directory, and the intermediate directories may not exist yet.
    Path(dirname).mkdir(parents=True, exist_ok=True)
def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
    """Return a local directory holding ``repo``, downloading it if needed.

    Args:
        repo (str): ``"repo_owner/repo_name[:tag_name]"``.
        force_reload (bool): when true, ignore any cached copy and download
            the archive again.
        verbose (bool): when true, report on stderr that a cached copy is used.
        source (str): ``'github'`` or ``'gitee'``.

    Returns:
        str: path of the cached repo directory inside the hub cache.
    """
    # Make sure the hub cache directory exists.
    cache_root = HUB_DIR
    _make_dirs(cache_root)

    owner, name, ref = _parse_repo_info(repo, source)

    # Branch names may contain '/', which would be read as a path separator
    # on both Linux and Windows, so flatten it. Backslash is not allowed in
    # GitHub branch names and needs no handling.
    safe_ref = ref.replace('/', '_')

    # GitHub renames the folder repo/v1.x.x to repo-1.x.x, so the extracted
    # folder name is unknown before the zip is inspected. The cache therefore
    # uses its own normalized folder name.
    repo_dir = os.path.join(cache_root, '_'.join([owner, name, safe_ref]))

    if not force_reload and os.path.exists(repo_dir):
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    # NOTE(review): assumes get_path_from_url stores the archive as
    # "<normalized branch>.zip" in the cache dir -- confirm for branch names
    # that contain '/'.
    archive_path = os.path.join(cache_root, safe_ref + '.zip')
    _remove_if_exists(archive_path)

    archive_url = _git_archive_link(owner, name, ref, source=source)
    get_path_from_url(archive_url, cache_root, decompress=False)

    with zipfile.ZipFile(archive_path) as archive:
        # The first archive entry is taken as the repo's top-level folder.
        top_level = archive.infolist()[0].filename
        unpacked_dir = os.path.join(cache_root, top_level)
        _remove_if_exists(unpacked_dir)
        archive.extractall(cache_root)

    # Drop the archive, clear any previous checkout, then move the freshly
    # extracted folder to its normalized cache name.
    _remove_if_exists(archive_path)
    _remove_if_exists(repo_dir)
    shutil.move(unpacked_dir, repo_dir)
    return repo_dir
def _load_entry_from_hubconf(m, name):
    """Look up the callable entrypoint ``name`` on hubconf module ``m``.

    Args:
        m: imported hubconf module (any object with attributes works).
        name (str): entrypoint name.

    Returns:
        The callable stored under ``name``.

    Raises:
        ValueError: if ``name`` is not a ``str``.
        RuntimeError: if ``m`` has no such attribute or it is not callable.
    """
    if not isinstance(name, str):
        raise ValueError(
            'Invalid input: model should be a str of function name')

    entry = getattr(m, name, None)
    # A missing attribute comes back as None, which is not callable either.
    if callable(entry):
        return entry
    raise RuntimeError('Cannot find callable {} in hubconf'.format(name))
def _check_module_exists(name):
    """Report whether module ``name`` can be imported.

    The check is done by actually importing the module.

    Returns:
        bool: ``True`` if the import succeeds, ``False`` on ``ImportError``.
    """
    try:
        __import__(name)
    except ImportError:
        return False
    else:
        return True
def _check_dependencies(m):
    """Verify that every package declared by hubconf module ``m`` is importable.

    The declaration is the optional module attribute named by
    ``VAR_DEPENDENCY``; a module without it passes the check.

    Raises:
        RuntimeError: naming every declared package that cannot be imported.
    """
    required = getattr(m, VAR_DEPENDENCY, None)
    if required is None:
        return

    unavailable = []
    for pkg in required:
        if not _check_module_exists(pkg):
            unavailable.append(pkg)

    if unavailable:
        raise RuntimeError('Missing dependencies: {}'.format(', '.join(
            unavailable)))
def list(repo_dir, source='github', force_reload=False):
    r"""
    List the entrypoints exposed by the `hubconf.py` of a repo.

    Args:
        repo_dir(str): github/gitee repo or local path
            github or gitee (str): a str with format "repo_owner/repo_name[:tag_name]"
                with an optional tag/branch. Without one, the branch defaults to `main`
                for `github` and `master` for `gitee`.
            local path (str): directory that contains `hubconf.py`
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and
            force a fresh download, default is `False`. Unused for `local`.

    Returns:
        entrypoints: a list with the names of all callables in `hubconf.py`
            whose name does not start with an underscore

    Raises:
        ValueError: if `source` is not one of the allowed values

    Example:
        ```python
        import paddle
        paddle.hub.list('lyuwenyu/paddlehub_demo:main', source='github', force_reload=False)
        ```
    """
    allowed_sources = ('github', 'gitee', 'local')
    if source not in allowed_sources:
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    # Hosted repos are downloaded (or served from the cache) first; a local
    # path is used as given.
    remote_sources = ('github', 'gitee')
    if source in remote_sources:
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    hubconf_name = MODULE_HUBCONF.partition('.')[0]
    hub_module = _import_module(hubconf_name, repo_dir)

    entrypoints = []
    for attr in dir(hub_module):
        if callable(getattr(hub_module, attr)) and not attr.startswith('_'):
            entrypoints.append(attr)
    return entrypoints
def help(repo_dir, model, source='github', force_reload=False):
    """
    Return the docstring of an entrypoint defined in the `hubconf.py` of a repo.

    Args:
        repo_dir(str): github/gitee repo or local path
            github or gitee (str): a str with format "repo_owner/repo_name[:tag_name]"
                with an optional tag/branch. Without one, the branch defaults to `main`
                for `github` and `master` for `gitee`.
            local path (str): directory that contains `hubconf.py`
        model (str): name of the entrypoint
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and
            force a fresh download, default is `False`. Unused for `local`.

    Return:
        docs: the `__doc__` of the entrypoint

    Raises:
        ValueError: if `source` is not one of the allowed values or `model` is not a str
        RuntimeError: if `hubconf.py` has no callable named `model`

    Example:
        ```python
        import paddle
        paddle.hub.help('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
        ```
    """
    allowed_sources = ('github', 'gitee', 'local')
    if source not in allowed_sources:
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    # Hosted repos are downloaded (or served from the cache) first; a local
    # path is used as given.
    remote_sources = ('github', 'gitee')
    if source in remote_sources:
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    hubconf_name = MODULE_HUBCONF.partition('.')[0]
    hub_module = _import_module(hubconf_name, repo_dir)
    return _load_entry_from_hubconf(hub_module, model).__doc__
def load(repo_dir, model, source='github', force_reload=False, **kwargs):
    """
    Build a model by calling an entrypoint defined in the `hubconf.py` of a repo.

    Args:
        repo_dir(str): github/gitee repo or local path
            github or gitee (str): a str with format "repo_owner/repo_name[:tag_name]"
                with an optional tag/branch. Without one, the branch defaults to `main`
                for `github` and `master` for `gitee`.
            local path (str): directory that contains `hubconf.py`
        model (str): name of the entrypoint
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and
            force a fresh download, default is `False`. Unused for `local`.
        **kwargs: keyword arguments forwarded to the entrypoint

    Return:
        whatever the entrypoint returns, typically a paddle model

    Raises:
        ValueError: if `source` is not one of the allowed values or `model` is not a str
        RuntimeError: if a package listed in the `dependencies` variable of `hubconf.py`
            cannot be imported, or if `hubconf.py` has no callable named `model`

    Example:
        ```python
        import paddle
        paddle.hub.load('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
        ```
    """
    allowed_sources = ('github', 'gitee', 'local')
    if source not in allowed_sources:
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    # Hosted repos are downloaded (or served from the cache) first; a local
    # path is used as given.
    remote_sources = ('github', 'gitee')
    if source in remote_sources:
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    hubconf_name = MODULE_HUBCONF.partition('.')[0]
    hub_module = _import_module(hubconf_name, repo_dir)

    # Dependencies are verified before the entrypoint is looked up and called.
    _check_dependencies(hub_module)

    build_model = _load_entry_from_hubconf(hub_module, model)
    return build_model(**kwargs)
......@@ -49,3 +49,4 @@ set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120)
set_tests_properties(test_dataset_imdb PROPERTIES TIMEOUT 300)
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600)
set_tests_properties(test_hapi_hub PROPERTIES TIMEOUT 300)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Packages that hub.load() checks are importable before it calls an entrypoint.
dependencies = ['paddle']
import paddle
# Imported under a private alias: hub.list() skips names that start with '_',
# so only the MM function defined below is reported as an entrypoint.
from test_hapi_hub_model import MM as _MM
# Hub entrypoint: builds the demo MM layer with `out_channels` output channels.
# `pretrained` is accepted but not used by this function.
# NOTE: the hub tests compare this function's docstring text, so leave it as is.
def MM(out_channels=8, pretrained=False):
    '''This is a test demo for paddle hub
    '''
    return _MM(out_channels)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import os
import paddle
from paddle.hapi import hub
import numpy as np
class TestHub(unittest.TestCase):
    """Exercise hub.list / hub.help / hub.load on a local and a GitHub repo.

    The GitHub cases go through the hub download/cache path.
    """

    def setUp(self):
        # The directory of this test file is used as the local hub repo.
        self.local_repo = os.path.dirname(os.path.abspath(__file__))
        self.github_repo = 'lyuwenyu/paddlehub_demo:main'

    def testLoad(self):
        model = hub.load(
            self.local_repo, model='MM', source='local', out_channels=8)

        data = paddle.rand((1, 3, 100, 100))
        out = model(data)
        np.testing.assert_equal(out.shape, [1, 8, 50, 50])

        # Fresh download first, then the cached copy.
        model = hub.load(
            self.github_repo, model='MM', source='github', force_reload=True)

        model = hub.load(
            self.github_repo,
            model='MM',
            source='github',
            force_reload=False,
            pretrained=False)

        # Same repo without an explicit branch: the default branch is used.
        model = hub.load(
            self.github_repo.split(':')[0],
            model='MM',
            source='github',
            force_reload=False,
            pretrained=False)

        model = hub.load(
            self.github_repo,
            model='MM',
            source='github',
            force_reload=False,
            pretrained=True,
            out_channels=8)

        data = paddle.ones((1, 3, 2, 2))
        out = model(data)

        # Reference output for an all-ones input; presumably produced with
        # the demo repo's pretrained weights.
        gt = np.array([
            1.53965068, 0., 0., 1.39455748, 0.72066200, 0.19773030, 2.09201908,
            0.37345418
        ])
        np.testing.assert_equal(out.shape, [1, 8, 1, 1])
        np.testing.assert_almost_equal(
            out.numpy(), gt.reshape(1, 8, 1, 1), decimal=5)

    def testHelp(self):
        docs1 = hub.help(
            self.local_repo,
            model='MM',
            source='local')

        docs2 = hub.help(
            self.github_repo, model='MM', source='github', force_reload=False)

        self.assertEqual(docs1, docs2)
        # Compare the stripped text so the check does not hinge on the
        # indentation that trails the docstring's last line.
        self.assertEqual(docs1.strip(), 'This is a test demo for paddle hub')

    def testList(self):
        models1 = hub.list(
            self.local_repo,
            source='local',
            force_reload=False)

        models2 = hub.list(
            self.github_repo,
            source='github',
            force_reload=False)

        self.assertEqual(models1, models2)
        self.assertEqual(models1, ['MM'])

    def testExcept(self):
        # An unknown source is rejected by every public function.
        with self.assertRaises(ValueError):
            hub.help(
                self.github_repo,
                model='MM',
                source='github-test',
                force_reload=False)

        with self.assertRaises(ValueError):
            hub.load(
                self.github_repo,
                model='MM',
                source='github-test',
                force_reload=False)

        with self.assertRaises(ValueError):
            hub.list(
                self.github_repo, source='github-test', force_reload=False)

        # A non-string model name is a ValueError ...
        with self.assertRaises(ValueError):
            hub.load(
                self.local_repo, model=123, source='local', force_reload=False)

        # ... while an unknown entrypoint name is a RuntimeError.
        with self.assertRaises(RuntimeError):
            hub.load(
                self.local_repo,
                model='123',
                source='local',
                force_reload=False)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class MM(nn.Layer):
    """Minimal demo network: one 3x3, stride-2 convolution followed by ReLU.

    Args:
        out_channels (int): number of output channels of the convolution.
    """

    def __init__(self, out_channels):
        super(MM, self).__init__()
        # Three input channels; kernel 3, stride 2, padding 1.
        self.conv = nn.Conv2D(
            in_channels=3,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1)

    def forward(self, x):
        return F.relu(self.conv(x))
......@@ -117,7 +117,11 @@ def _get_unique_endpoints(trainer_endpoints):
return unique_endpoints
def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
......@@ -152,7 +156,8 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册