提交 30aace46 编写于 作者: W wuzewu

Add `ignore_env_mismatch` in hub.Module.

上级 98f173e0
...@@ -25,11 +25,22 @@ from paddlehub.server.server import CacheUpdater ...@@ -25,11 +25,22 @@ from paddlehub.server.server import CacheUpdater
@register(name='hub.install', description='Install PaddleHub module.') @register(name='hub.install', description='Install PaddleHub module.')
class InstallCommand: class InstallCommand:
def __init__(self):
self.parser = argparse.ArgumentParser(prog='hub install', add_help=True)
self.parser.add_argument(
'--ignore_env_mismatch',
action='store_true',
help='Whether to ignore the environment mismatch when installing the Module.')
def execute(self, argv: List) -> bool: def execute(self, argv: List) -> bool:
if not argv: if not argv:
print("ERROR: You must give at least one module to install.") print("ERROR: You must give at least one module to install.")
return False return False
options = [arg for arg in argv if arg.startswith('-')]
argv = [arg for arg in argv if not arg.startswith('-')]
args = self.parser.parse_args(options)
manager = LocalModuleManager() manager = LocalModuleManager()
for _arg in argv: for _arg in argv:
if os.path.exists(_arg) and os.path.isdir(_arg): if os.path.exists(_arg) and os.path.isdir(_arg):
...@@ -41,5 +52,5 @@ class InstallCommand: ...@@ -41,5 +52,5 @@ class InstallCommand:
name = _arg[0] name = _arg[0]
version = None if len(_arg) == 1 else _arg[1] version = None if len(_arg) == 1 else _arg[1]
CacheUpdater("hub_install", name, version).start() CacheUpdater("hub_install", name, version).start()
manager.install(name=name, version=version) manager.install(name=name, version=version, ignore_env_mismatch=args.ignore_env_mismatch)
return True return True
...@@ -155,7 +155,8 @@ class LocalModuleManager(object): ...@@ -155,7 +155,8 @@ class LocalModuleManager(object):
version: str = None, version: str = None,
source: str = None, source: str = None,
update: bool = False, update: bool = False,
branch: str = None) -> HubModule: branch: str = None,
ignore_env_mismatch: bool = False) -> HubModule:
''' '''
Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a
module that meets the conditions (both name and version) already installed, the installation step will be skipped. module that meets the conditions (both name and version) already installed, the installation step will be skipped.
...@@ -168,6 +169,7 @@ class LocalModuleManager(object): ...@@ -168,6 +169,7 @@ class LocalModuleManager(object):
url (str|optional): url points to a archive file containing module code url (str|optional): url points to a archive file containing module code
version (str|optional): module version, use with name parameter version (str|optional): module version, use with name parameter
source (str|optional): source containing module code, use with name paramete source (str|optional): source containing module code, use with name paramete
ignore_env_mismatch (str|optional): Whether to ignore the environment mismatch when installing the Module.
''' '''
if name: if name:
...@@ -185,7 +187,7 @@ class LocalModuleManager(object): ...@@ -185,7 +187,7 @@ class LocalModuleManager(object):
return hub_module_cls return hub_module_cls
if source: if source:
return self._install_from_source(name, version, source, update, branch) return self._install_from_source(name, version, source, update, branch)
return self._install_from_name(name, version) return self._install_from_name(name, version, ignore_env_mismatch)
elif directory: elif directory:
return self._install_from_directory(directory) return self._install_from_directory(directory)
elif archive: elif archive:
...@@ -255,7 +257,7 @@ class LocalModuleManager(object): ...@@ -255,7 +257,7 @@ class LocalModuleManager(object):
return self._install_from_archive(file) return self._install_from_archive(file)
def _install_from_name(self, name: str, version: str = None) -> HubModule: def _install_from_name(self, name: str, version: str = None, ignore_env_mismatch: bool = False) -> HubModule:
'''Install HubModule by name search result''' '''Install HubModule by name search result'''
result = module_server.search_module(name=name, version=version) result = module_server.search_module(name=name, version=version)
for item in result: for item in result:
...@@ -277,7 +279,24 @@ class LocalModuleManager(object): ...@@ -277,7 +279,24 @@ class LocalModuleManager(object):
# Cannot find a HubModule that meets the version # Cannot find a HubModule that meets the version
if valid_infos: if valid_infos:
if not ignore_env_mismatch:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version) raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
# If `ignore_env_mismatch` is set, ignore the problem of environmental mismatch, such as PaddlePaddle or PaddleHub
# version incompatibility. This may cause some unexpected problems during installation or running, but it is useful
# in some cases, for example, the development version of PaddlePaddle(with version number `0.0.0`) is installed
# locally.
if version:
if version in valid_infos:
url = valid_infos[version]['url']
else:
raise HubModuleNotFoundError(name=name, info=module_infos, version=version)
else:
# Get the maximum version number.
version = sorted([utils.Version(_v) for _v in valid_infos.keys()])[-1]
url = valid_infos[str(version)]['url']
log.logger.warning('Ignore environmental mismatch of The Module {}-{}'.format(name, version))
return self._install_from_url(url)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version) raise HubModuleNotFoundError(name=name, info=module_infos, version=version)
def _install_from_source(self, name: str, version: str, source: str, update: bool = False, def _install_from_source(self, name: str, version: str, source: str, update: bool = False,
......
...@@ -359,16 +359,16 @@ class Module(object): ...@@ -359,16 +359,16 @@ class Module(object):
name(str): Module name. name(str): Module name.
directory(str|optional): Directory of the module to be loaded, only takes effect when the `name` is not specified. directory(str|optional): Directory of the module to be loaded, only takes effect when the `name` is not specified.
version(str|optional): The version limit of the module, only takes effect when the `name` is specified. When the local version(str|optional): The version limit of the module, only takes effect when the `name` is specified. When the local
Module does not meet the specified version conditions, PaddleHub will re-request the server to Module does not meet the specified version conditions, PaddleHub will re-request the server to download the
download the appropriate Module. Default to None, This means that the local Module will be used. appropriate Module. Default to None, This means that the local Module will be used. If the Module does not exist,
If the Module does not exist, PaddleHub will download the latest version available from the PaddleHub will download the latest version available from the server according to the usage environment.
server according to the usage environment.
source(str|optional): Url of a git repository. If this parameter is specified, PaddleHub will no longer download the source(str|optional): Url of a git repository. If this parameter is specified, PaddleHub will no longer download the
specified Module from the default server, but will look for it in the specified repository. specified Module from the default server, but will look for it in the specified repository. Default to None.
Default to None. update(bool|optional): Whether to update the locally cached git repository, only takes effect when the `source` is
update(bool|optional): Whether to update the locally cached git repository, only takes effect when the `source` specified. Default to False.
is specified. Default to False.
branch(str|optional): The branch of the specified git repository. Default to None. branch(str|optional): The branch of the specified git repository. Default to None.
ignore_env_mismatch(bool|optional): Whether to ignore the environment mismatch when installing the Module. Default to
False.
''' '''
def __new__(cls, def __new__(cls,
...@@ -379,13 +379,20 @@ class Module(object): ...@@ -379,13 +379,20 @@ class Module(object):
source: str = None, source: str = None,
update: bool = False, update: bool = False,
branch: str = None, branch: str = None,
ignore_env_mismatch: bool = False,
**kwargs): **kwargs):
if cls.__name__ == 'Module': if cls.__name__ == 'Module':
from paddlehub.server.server import CacheUpdater from paddlehub.server.server import CacheUpdater
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name: if name:
module = cls.init_with_name( module = cls.init_with_name(
name=name, version=version, source=source, update=update, branch=branch, **kwargs) name=name,
version=version,
source=source,
update=update,
branch=branch,
ignore_env_mismatch=ignore_env_mismatch,
**kwargs)
CacheUpdater("update_cache", module=name, version=version).start() CacheUpdater("update_cache", module=name, version=version).start()
elif directory: elif directory:
module = cls.init_with_directory(directory=directory, **kwargs) module = cls.init_with_directory(directory=directory, **kwargs)
...@@ -470,13 +477,20 @@ class Module(object): ...@@ -470,13 +477,20 @@ class Module(object):
source: str = None, source: str = None,
update: bool = False, update: bool = False,
branch: str = None, branch: str = None,
ignore_env_mismatch: bool = False,
**kwargs) -> Union[RunModule, ModuleV1]: **kwargs) -> Union[RunModule, ModuleV1]:
'''Initialize Module according to the specified name.''' '''Initialize Module according to the specified name.'''
from paddlehub.module.manager import LocalModuleManager from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager() manager = LocalModuleManager()
user_module_cls = manager.search(name, source=source, branch=branch) user_module_cls = manager.search(name, source=source, branch=branch)
if not user_module_cls or not user_module_cls.version.match(version): if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name=name, version=version, source=source, update=update, branch=branch) user_module_cls = manager.install(
name=name,
version=version,
source=source,
update=update,
branch=branch,
ignore_env_mismatch=ignore_env_mismatch)
directory = manager._get_normalized_path(user_module_cls.name) directory = manager._get_normalized_path(user_module_cls.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册