From 4e460d7b6475d76dcf27581714bea1cf83571d26 Mon Sep 17 00:00:00 2001 From: Wenyu Date: Sun, 25 Apr 2021 14:22:30 +0800 Subject: [PATCH] Add hub Module for easy to use pre-trained models. (#31873) * add Hub Module for easy to use pre-trained models. * support list, load, help functions. * support load models by github, gitee, local Co-authored-by: LielinJiang --- python/paddle/__init__.py | 2 + python/paddle/hapi/__init__.py | 1 + python/paddle/hapi/hub.py | 277 +++++++++++++++++++++ python/paddle/tests/CMakeLists.txt | 1 + python/paddle/tests/hubconf.py | 24 ++ python/paddle/tests/test_hapi_hub.py | 132 ++++++++++ python/paddle/tests/test_hapi_hub_model.py | 29 +++ python/paddle/utils/download.py | 9 +- 8 files changed, 473 insertions(+), 2 deletions(-) create mode 100644 python/paddle/hapi/hub.py create mode 100644 python/paddle/tests/hubconf.py create mode 100644 python/paddle/tests/test_hapi_hub.py create mode 100644 python/paddle/tests/test_hapi_hub_model.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 129a0381125..94091c94bb5 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -300,6 +300,8 @@ from .hapi import Model from .hapi import callbacks from .hapi import summary from .hapi import flops +from .hapi import hub + import paddle.text import paddle.vision diff --git a/python/paddle/hapi/__init__.py b/python/paddle/hapi/__init__.py index 0aea557a28c..6b7672828e6 100644 --- a/python/paddle/hapi/__init__.py +++ b/python/paddle/hapi/__init__.py @@ -15,6 +15,7 @@ from . import logger from . import callbacks from . import model_summary +from . import hub from . import model from .model import * diff --git a/python/paddle/hapi/hub.py b/python/paddle/hapi/hub.py new file mode 100644 index 00000000000..31a8be0944f --- /dev/null +++ b/python/paddle/hapi/hub.py @@ -0,0 +1,277 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import sys +import shutil +import zipfile +from paddle.utils.download import get_path_from_url + +DEFAULT_CACHE_DIR = '~/.cache' +VAR_DEPENDENCY = 'dependencies' +MODULE_HUBCONF = 'hubconf.py' +HUB_DIR = os.path.expanduser(os.path.join('~', '.cache', 'paddle', 'hub')) + + +def _remove_if_exists(path): + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + else: + shutil.rmtree(path) + + +def _import_module(name, repo_dir): + sys.path.insert(0, repo_dir) + try: + hub_module = __import__(name) + sys.modules.pop(name) + except ImportError: + sys.path.remove(repo_dir) + raise RuntimeError( + 'Cannot import `{}`, please make sure `{}`.py in repo root dir'. 
+ format(name, name)) + + sys.path.remove(repo_dir) + + return hub_module + + +def _git_archive_link(repo_owner, repo_name, branch, source): + if source == 'github': + return 'https://github.com/{}/{}/archive/{}.zip'.format( + repo_owner, repo_name, branch) + elif source == 'gitee': + return 'https://gitee.com/{}/{}/repository/archive/{}.zip'.format( + repo_owner, repo_name, branch) + + +def _parse_repo_info(repo, source): + branch = 'main' if source == 'github' else 'master' + if ':' in repo: + repo_info, branch = repo.split(':') + else: + repo_info = repo + repo_owner, repo_name = repo_info.split('/') + return repo_owner, repo_name, branch + + +def _make_dirs(dirname): + try: + from pathlib import Path + except ImportError: + from pathlib2 import Path + Path(dirname).mkdir(exist_ok=True) + + +def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'): + # Setup hub_dir to save downloaded files + hub_dir = HUB_DIR + + _make_dirs(hub_dir) + + # Parse github/gitee repo information + repo_owner, repo_name, branch = _parse_repo_info(repo, source) + # Github allows branch name with slash '/', + # this causes confusion with path on both Linux and Windows. + # Backslash is not allowed in Github branch name so no need to + # to worry about it. + normalized_br = branch.replace('/', '_') + # Github renames folder repo/v1.x.x to repo-1.x.x + # We don't know the repo name before downloading the zip file + # and inspect name from it. + # To check if cached repo exists, we need to normalize folder names. 
+ repo_dir = os.path.join(hub_dir, + '_'.join([repo_owner, repo_name, normalized_br])) + + use_cache = (not force_reload) and os.path.exists(repo_dir) + + if use_cache: + if verbose: + sys.stderr.write('Using cache found in {}\n'.format(repo_dir)) + else: + cached_file = os.path.join(hub_dir, normalized_br + '.zip') + _remove_if_exists(cached_file) + + url = _git_archive_link(repo_owner, repo_name, branch, source=source) + + get_path_from_url(url, hub_dir, decompress=False) + + with zipfile.ZipFile(cached_file) as cached_zipfile: + extraced_repo_name = cached_zipfile.infolist()[0].filename + extracted_repo = os.path.join(hub_dir, extraced_repo_name) + _remove_if_exists(extracted_repo) + # Unzip the code and rename the base folder + cached_zipfile.extractall(hub_dir) + + _remove_if_exists(cached_file) + _remove_if_exists(repo_dir) + # rename the repo + shutil.move(extracted_repo, repo_dir) + + return repo_dir + + +def _load_entry_from_hubconf(m, name): + '''load entry from hubconf + ''' + if not isinstance(name, str): + raise ValueError( + 'Invalid input: model should be a str of function name') + + func = getattr(m, name, None) + + if func is None or not callable(func): + raise RuntimeError('Cannot find callable {} in hubconf'.format(name)) + + return func + + +def _check_module_exists(name): + try: + __import__(name) + return True + except ImportError: + return False + + +def _check_dependencies(m): + dependencies = getattr(m, VAR_DEPENDENCY, None) + + if dependencies is not None: + missing_deps = [ + pkg for pkg in dependencies if not _check_module_exists(pkg) + ] + if len(missing_deps): + raise RuntimeError('Missing dependencies: {}'.format(', '.join( + missing_deps))) + + +def list(repo_dir, source='github', force_reload=False): + r""" + List all entrypoints available in `github` hubconf. + + Args: + repo_dir(str): github or local path + github path (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional + tag/branch. 
The default branch is `main` if not specified. + local path (str): local repo path + source (str): `github` | `gitee` | `local`, default is `github` + force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`. + Returns: + entrypoints: a list of available entrypoint names + + Example: + ```python + import paddle + + paddle.hub.list('lyuwenyu/paddlehub_demo:main', source='github', force_reload=False) + + ``` + """ + if source not in ('github', 'gitee', 'local'): + raise ValueError( + 'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'. + format(source)) + + if source in ('github', 'gitee'): + repo_dir = _get_cache_or_reload( + repo_dir, force_reload, True, source=source) + + hub_module = _import_module(MODULE_HUBCONF.split('.')[0], repo_dir) + + entrypoints = [ + f for f in dir(hub_module) + if callable(getattr(hub_module, f)) and not f.startswith('_') + ] + + return entrypoints + + +def help(repo_dir, model, source='github', force_reload=False): + """ + Show help information of model + + Args: + repo_dir(str): github or local path + github path (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional + tag/branch. The default branch is `main` if not specified. + local path (str): local repo path + model (str): model name + source (str): `github` | `gitee` | `local`, default is `github` + force_reload (bool, optional): default is `False` + Return: + docs + + Example: + ```python + import paddle + + paddle.hub.help('lyuwenyu/paddlehub_demo:main', model='MM', source='github') + ``` + """ + if source not in ('github', 'gitee', 'local'): + raise ValueError( + 'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'. 
+ format(source)) + + if source in ('github', 'gitee'): + repo_dir = _get_cache_or_reload( + repo_dir, force_reload, True, source=source) + + hub_module = _import_module(MODULE_HUBCONF.split('.')[0], repo_dir) + + entry = _load_entry_from_hubconf(hub_module, model) + + return entry.__doc__ + + +def load(repo_dir, model, source='github', force_reload=False, **kwargs): + """ + Load model + + Args: + repo_dir(str): github or local path + github path (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional + tag/branch. The default branch is `main` if not specified. + local path (str): local repo path + model (str): model name + source (str): `github` | `gitee` | `local`, default is `github` + force_reload (bool, optional), default is `False` + **kwargs: parameters using for model + Return: + paddle model + Example: + ```python + import paddle + paddle.hub.load('lyuwenyu/paddlehub_demo:main', model='MM', source='github') + ``` + """ + if source not in ('github', 'gitee', 'local'): + raise ValueError( + 'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'. 
+ format(source)) + + if source in ('github', 'gitee'): + repo_dir = _get_cache_or_reload( + repo_dir, force_reload, True, source=source) + + hub_module = _import_module(MODULE_HUBCONF.split('.')[0], repo_dir) + + _check_dependencies(hub_module) + + entry = _load_entry_from_hubconf(hub_module, model) + + return entry(**kwargs) diff --git a/python/paddle/tests/CMakeLists.txt b/python/paddle/tests/CMakeLists.txt index 9a676b6b739..bb572973fdb 100644 --- a/python/paddle/tests/CMakeLists.txt +++ b/python/paddle/tests/CMakeLists.txt @@ -49,3 +49,4 @@ set_tests_properties(test_vision_models PROPERTIES TIMEOUT 120) set_tests_properties(test_dataset_uci_housing PROPERTIES TIMEOUT 120) set_tests_properties(test_dataset_imdb PROPERTIES TIMEOUT 300) set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600) +set_tests_properties(test_hapi_hub PROPERTIES TIMEOUT 300) diff --git a/python/paddle/tests/hubconf.py b/python/paddle/tests/hubconf.py new file mode 100644 index 00000000000..4b4a853ef2c --- /dev/null +++ b/python/paddle/tests/hubconf.py @@ -0,0 +1,24 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +dependencies = ['paddle'] + +import paddle +from test_hapi_hub_model import MM as _MM + + +def MM(out_channels=8, pretrained=False): + '''This is a test demo for paddle hub + ''' + return _MM(out_channels) diff --git a/python/paddle/tests/test_hapi_hub.py b/python/paddle/tests/test_hapi_hub.py new file mode 100644 index 00000000000..06000d6c833 --- /dev/null +++ b/python/paddle/tests/test_hapi_hub.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import division +from __future__ import print_function + +import unittest +import os + +import paddle +from paddle.hapi import hub + +import numpy as np + + +class TestHub(unittest.TestCase): + def setUp(self, ): + self.local_repo = os.path.dirname(os.path.abspath(__file__)) + self.github_repo = 'lyuwenyu/paddlehub_demo:main' + + def testLoad(self, ): + model = hub.load( + self.local_repo, model='MM', source='local', out_channels=8) + + data = paddle.rand((1, 3, 100, 100)) + out = model(data) + np.testing.assert_equal(out.shape, [1, 8, 50, 50]) + + model = hub.load( + self.github_repo, model='MM', source='github', force_reload=True) + + model = hub.load( + self.github_repo, + model='MM', + source='github', + force_reload=False, + pretrained=False) + + model = hub.load( + self.github_repo.split(':')[0], + model='MM', + source='github', + force_reload=False, + pretrained=False) + + model = hub.load( + self.github_repo, + model='MM', + source='github', + force_reload=False, + pretrained=True, + out_channels=8) + + data = paddle.ones((1, 3, 2, 2)) + out = model(data) + + gt = np.array([ + 1.53965068, 0., 0., 1.39455748, 0.72066200, 0.19773030, 2.09201908, + 0.37345418 + ]) + np.testing.assert_equal(out.shape, [1, 8, 1, 1]) + np.testing.assert_almost_equal( + out.numpy(), gt.reshape(1, 8, 1, 1), decimal=5) + + def testHelp(self, ): + docs1 = hub.help( + self.local_repo, + model='MM', + source='local', ) + + docs2 = hub.help( + self.github_repo, model='MM', source='github', force_reload=False) + + assert docs1 == docs2 == 'This is a test demo for paddle hub\n ', '' + + def testList(self, ): + models1 = hub.list( + self.local_repo, + source='local', + force_reload=False, ) + + models2 = hub.list( + self.github_repo, + source='github', + force_reload=False, ) + + assert models1 == models2 == ['MM'], '' + + def testExcept(self, ): + with self.assertRaises(ValueError): + _ = hub.help( + self.github_repo, + model='MM', + source='github-test', + 
force_reload=False) + + with self.assertRaises(ValueError): + _ = hub.load( + self.github_repo, + model='MM', + source='github-test', + force_reload=False) + + with self.assertRaises(ValueError): + _ = hub.list( + self.github_repo, source='github-test', force_reload=False) + + with self.assertRaises(ValueError): + _ = hub.load( + self.local_repo, model=123, source='local', force_reload=False) + + with self.assertRaises(RuntimeError): + _ = hub.load( + self.local_repo, + model='123', + source='local', + force_reload=False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tests/test_hapi_hub_model.py b/python/paddle/tests/test_hapi_hub_model.py new file mode 100644 index 00000000000..774c7f6f33a --- /dev/null +++ b/python/paddle/tests/test_hapi_hub_model.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + + +class MM(nn.Layer): + def __init__(self, out_channels): + super(MM, self).__init__() + self.conv = nn.Conv2D(3, out_channels, 3, 2, 1) + + def forward(self, x): + out = self.conv(x) + out = F.relu(out) + + return out diff --git a/python/paddle/utils/download.py b/python/paddle/utils/download.py index b7d7d0b5adb..dda8abeff21 100644 --- a/python/paddle/utils/download.py +++ b/python/paddle/utils/download.py @@ -117,7 +117,11 @@ def _get_unique_endpoints(trainer_endpoints): return unique_endpoints -def get_path_from_url(url, root_dir, md5sum=None, check_exist=True): +def get_path_from_url(url, + root_dir, + md5sum=None, + check_exist=True, + decompress=True): """ Download from given url to root_dir. if file or directory specified by url is exists under root_dir, return the path directly, otherwise download @@ -152,7 +156,8 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True): time.sleep(1) if ParallelEnv().current_endpoint in unique_endpoints: - if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath): + if decompress and (tarfile.is_tarfile(fullpath) or + zipfile.is_zipfile(fullpath)): fullpath = _decompress(fullpath) return fullpath -- GitLab