diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py index d2027596cb3048938a0c7e99d69ca908d0ca49f0..e59893b076a99611835868908f8fe9e6e2d0bec5 100644 --- a/paddlehub/module/manager.py +++ b/paddlehub/module/manager.py @@ -150,7 +150,9 @@ class LocalModuleManager(object): archive: str = None, url: str = None, version: str = None, - source: str = None) -> HubModule: + source: str = None, + update: bool = False, + branch: str = None) -> HubModule: ''' Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a module that meets the conditions (both name and version) already installed, the installation step will be @@ -167,7 +169,7 @@ class LocalModuleManager(object): if name: lock = filelock.FileLock(os.path.join(TMP_HOME, name)) with lock: - hub_module_cls = self.search(name) + hub_module_cls = self.search(name, source, branch) if hub_module_cls and hub_module_cls.version.match(version): directory = self._get_normalized_path(hub_module_cls.name) if version: @@ -177,7 +179,9 @@ class LocalModuleManager(object): msg = 'Module {} already installed in {}'.format(hub_module_cls.name, directory) log.logger.info(msg) return hub_module_cls - return self._install_from_name(name, version, source) + if source: + return self._install_from_source(name, version, source, update, branch) + return self._install_from_name(name, version) elif directory: return self._install_from_directory(directory) elif archive: @@ -201,19 +205,30 @@ class LocalModuleManager(object): log.logger.info('Successfully uninstalled {}'.format(name)) return True - def search(self, name: str) -> HubModule: + def search(self, name: str, source: str = None, branch: str = None) -> HubModule: '''Return HubModule If a HubModule with a specific name is found, otherwise None.''' + module = None + if name in self._local_modules: - return self._local_modules[name] + module = self._local_modules[name] + else: + module_dir = self._get_normalized_path(name) + if os.path.exists(module_dir): + try: + module = self._local_modules[name] = HubModule.load(module_dir) + except: + log.logger.warning('An error was encountered while loading {}'.format(name)) - module_dir = self._get_normalized_path(name) - if os.path.exists(module_dir): - try: - self._local_modules[name] = HubModule.load(module_dir) - return self._local_modules[name] - except: - log.logger.warning('An error was encountered while loading {}'.format(name)) - return None + if not module: + return None + + if source and source != module.source: + return None + + if branch and branch != module.branch: + return None + + return module def list(self) -> List[HubModule]: '''List all installed HubModule.''' @@ -231,23 +246,17 @@ class LocalModuleManager(object): return self._install_from_archive(file) - def _install_from_name(self, name: str, version: str = None, source: str = None) -> HubModule: + def _install_from_name(self, name: str, version: str = None) -> HubModule: '''Install HubModule by name search result''' - if name in self._local_modules: - if self._local_modules[name].version.match(version): - return self._local_modules[name] - - result = module_server.search_module(name=name, version=version, source=source) + result = module_server.search_module(name=name, version=version) for item in result: if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): - if source or 'source' in item: - return self._install_from_source(result) return self._install_from_url(item['url']) - module_infos = module_server.get_module_info(name=name, source=source) + module_infos = module_server.get_module_info(name=name) # The HubModule with the specified name cannot be found if not module_infos: - raise HubModuleNotFoundError(name=name, version=version, source=source) + raise HubModuleNotFoundError(name=name, version=version) valid_infos = {} if version: @@ -260,29 +269,43 @@ class LocalModuleManager(object): # Cannot find a HubModule that meets the version if valid_infos: raise EnvironmentMismatchError(name=name, info=valid_infos, version=version) - raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source) + raise HubModuleNotFoundError(name=name, info=module_infos, version=version) - def _install_from_source(self, source: str) -> HubModule: + def _install_from_source(self, name: str, version: str, source: str, update: bool = False, + branch: str = None) -> HubModule: '''Install a HubModule from Git Repo''' - name = source['name'] - cls_name = source['class'] - path = source['path'] - # uninstall local module - if self.search(name): - self.uninstall(name) - - os.makedirs(self._get_normalized_path(name)) - module_file = os.path.join(self._get_normalized_path(name), 'module.py') - - # Generate a module.py file to reference objects from Git Repo - with open(module_file, 'w') as file: - file.write('import sys\n\n') - file.write('sys.path.insert(0, \'{}\')\n'.format(path)) - file.write('from hubconf import {}\n'.format(cls_name)) - file.write('sys.path.pop(0)\n') - - self._local_modules[name] = HubModule.load(self._get_normalized_path(name)) - return self._local_modules[name] + result = module_server.search_module(name=name, source=source, version=version, update=update, branch=branch) + for item in result: + if item['name'] == name and item['version'].match(version): + + # uninstall local module + if self.search(name): + self.uninstall(name) + + installed_path = self._get_normalized_path(name) + if not os.path.exists(installed_path): + os.makedirs(installed_path) + module_file = os.path.join(installed_path, 'module.py') + + # Generate a module.py file to reference objects from Git Repo + with open(module_file, 'w') as file: + file.write('import sys\n\n') + file.write('sys.path.insert(0, \'{}\')\n'.format(item['path'])) + file.write('from hubconf import {}\n'.format(item['class'])) + file.write('sys.path.pop(0)\n') + + source_info_file = os.path.join(installed_path, '_source_info.yaml') + with open(source_info_file, 'w') as file: + file.write('source: {}\n'.format(source)) + file.write('branch: {}'.format(branch)) + + self._local_modules[name] = HubModule.load(installed_path) + if version: + log.logger.info('Successfully installed {}-{}'.format(name, version)) + else: + log.logger.info('Successfully installed {}'.format(name)) + return self._local_modules[name] + raise HubModuleNotFoundError(name=name, version=version, source=source) def _install_from_directory(self, directory: str) -> HubModule: '''Install a HubModule from directory containing module.py''' diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index 175605d31417a0c2fae271e00441731a75b1de8e..0a150f7ecb84b3efbad1e7761cc811ac62e69340 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -22,7 +22,7 @@ from typing import Callable, Generic, List, Optional from easydict import EasyDict -from paddlehub.utils import log, utils +from paddlehub.utils import parser, log, utils from paddlehub.compat.module.module_v1 import ModuleV1 @@ -62,11 +62,17 @@ class Module(object): ''' ''' - def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs): + def __new__(cls, + name: str = None, + directory: str = None, + version: str = None, + source: str = None, + update: bool = False, + **kwargs): if cls.__name__ == 'Module': # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') if name: - module = cls.init_with_name(name=name, version=version, **kwargs) + module = cls.init_with_name(name=name, version=version, source=source, update=update, **kwargs) elif directory: module = cls.init_with_directory(directory=directory, **kwargs) else: @@ -81,7 +87,7 @@ class Module(object): if directory.endswith(os.sep): directory = directory[:-1] - # If module description file existed, try to load as ModuleV1 + # If the module description file existed, try to load as ModuleV1 desc_file = os.path.join(directory, 'module_desc.pb') if os.path.exists(desc_file): return ModuleV1.load(directory) @@ -99,6 +105,15 @@ class Module(object): raise InvalidHubModule(directory) user_module_cls.directory = directory + + source_info_file = os.path.join(directory, '_source_info.yaml') + if os.path.exists(source_info_file): + info = parser.yaml_parser.parse(source_info_file) + user_module_cls.source = info.get('source', '') + user_module_cls.branch = info.get('branch', '') + else: + user_module_cls.source = '' + user_module_cls.branch = '' return user_module_cls @classmethod @@ -128,14 +143,20 @@ class Module(object): raise InvalidHubModule(directory) @classmethod - def init_with_name(cls, name: str, version: str = None, **kwargs): + def init_with_name(cls, + name: str, + version: str = None, + source: str = None, + update: bool = False, + branch: str = None, + **kwargs): ''' ''' from paddlehub.module.manager import LocalModuleManager manager = LocalModuleManager() - user_module_cls = manager.search(name) + user_module_cls = manager.search(name, source=source, branch=branch) if not user_module_cls or not user_module_cls.version.match(version): - user_module_cls = manager.install(name=name, version=version) + user_module_cls = manager.install(name=name, version=version, source=source, update=update, branch=branch) directory = manager._get_normalized_path(user_module_cls.name) diff --git a/paddlehub/server/git_source.py b/paddlehub/server/git_source.py index 906373c90893febb16a4e6924464572aed5dcf42..2e80eea535b5ca50a7f08f159879d7a36087a4e2 100644 --- a/paddlehub/server/git_source.py +++ b/paddlehub/server/git_source.py @@ -19,14 +19,13 @@ import os import sys from collections import OrderedDict from typing import List -from urllib.parse import urlparse import git from git import Repo -from paddlehub.module.module import Module as HubModule +from paddlehub.module.module import RunModule from paddlehub.env import SOURCES_HOME -from paddlehub.utils import log +from paddlehub.utils import log, utils class GitSource(object): @@ -40,9 +39,8 @@ class GitSource(object): def __init__(self, url: str, path: str = None): self.url = url - self._parse_result = urlparse(self.url) + self.path = os.path.join(SOURCES_HOME, utils.md5(url)) - self.path = os.path.join(SOURCES_HOME, self._parse_result.path[1:]) if self.path.endswith('.git'): self.path = self.path[:-4] @@ -56,8 +54,21 @@ class GitSource(object): self.hub_modules = OrderedDict() self.load_hub_modules() + def checkout(self, branch: str): + try: + self.repo.git.checkout(branch) + # reload modules + self.load_hub_modules() + except: + log.logger.warning('An error occurred while checkout {}'.format(self.path)) + def update(self): - self.repo.remote().pull() + try: + self.repo.remote().pull(self.repo.branches[0]) + # reload modules + self.load_hub_modules() + except: + log.logger.warning('An error occurred while update {}'.format(self.path)) def load_hub_modules(self): if 'hubconf' in sys.modules: @@ -68,11 +79,12 @@ class GitSource(object): py_module = importlib.import_module('hubconf') for _item, _cls in inspect.getmembers(py_module, inspect.isclass): _item = py_module.__dict__[_item] - if issubclass(_item, HubModule): + if issubclass(_item, RunModule): self.hub_modules[_item.name] = _item except: - raise + self.hub_modules = OrderedDict() log.logger.warning('An error occurred while loading {}'.format(self.path)) + sys.path.remove(self.path) def search_module(self, name: str, version: str = None) -> List[dict]: diff --git a/paddlehub/server/server.py b/paddlehub/server/server.py index 60b0ab0b00f54f00c121b093b8a9f396b2e51a67..c07ac1c8cbced117b59f248962cc04431e749dd8 100644 --- a/paddlehub/server/server.py +++ b/paddlehub/server/server.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import List from paddlehub.server import ServerSource, GitSource +from paddlehub.utils import utils PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub' @@ -27,7 +28,7 @@ class HubServer(object): def __init__(self): self.sources = OrderedDict() - def _generate_source(self, url: str, source_type: str = 'server'): + def _generate_source(self, url: str, source_type: str = 'git'): if source_type == 'server': source = ServerSource(url) elif source_type == 'git': @@ -36,16 +37,29 @@ class HubServer(object): raise RuntimeError() return source - def add_source(self, url: str, key: str = None, source_type: str = 'server'): + def _get_source_key(self, url: str): + return 'source_{}'.format(utils.md5(url)) + + def add_source(self, url: str, source_type: str = 'git'): '''Add a module source(GitSource or ServerSource)''' - key = "source_{}".format(len(self.sources)) if not key else key + key = self._get_source_key(url) self.sources[key] = self._generate_source(url, source_type) def remove_source(self, url: str = None, key: str = None): '''Remove a module source''' self.sources.pop(key) - def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]: + def get_source(self, url: str): + '''''' + key = self._get_source_key(url) + return self.sources.get(key, None) + + def search_module(self, + name: str, + version: str = None, + source: str = None, + update: bool = False, + branch: str = None) -> List[dict]: ''' Search PaddleHub module @@ -53,9 +67,16 @@ class HubServer(object): name(str) : PaddleHub module name version(str) : PaddleHub module version ''' - return self.search_resource(type='module', name=name, version=version, source=source) - - def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]: + return self.search_resource( + type='module', name=name, version=version, source=source, update=update, branch=branch) + + def search_resource(self, + type: str, + name: str, + version: str = None, + source: str = None, + update: bool = False, + branch: str = None) -> List[dict]: ''' Search PaddleHub Resource @@ -66,6 +87,12 @@ class HubServer(object): ''' sources = self.sources.values() if not source else [self._generate_source(source)] for source in sources: + if isinstance(source, GitSource) and update: + source.update() + + if isinstance(source, GitSource) and branch: + source.checkout(branch) + result = source.search_resource(name=name, type=type, version=version) if result: return result @@ -83,4 +110,4 @@ class HubServer(object): module_server = HubServer() -module_server.add_source(PADDLEHUB_PUBLIC_SERVER) +module_server.add_source(PADDLEHUB_PUBLIC_SERVER, source_type='server') diff --git a/paddlehub/utils/utils.py b/paddlehub/utils/utils.py index 13e7dd4bd5e1157f3a1867e2e2d1f0195016033a..7495d5710f327c11d2e884befd38b06e7f4be890 100644 --- a/paddlehub/utils/utils.py +++ b/paddlehub/utils/utils.py @@ -16,6 +16,7 @@ import base64 import contextlib import cv2 +import hashlib import importlib import math import os @@ -24,10 +25,10 @@ import sys import time import tempfile import types -import numpy as np from typing import Generator from urllib.parse import urlparse +import numpy as np import packaging.version import paddlehub.env as hubenv @@ -241,6 +242,13 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType: py_module_name(str) : Module name to be loaded ''' sys.path.insert(0, python_path) + + # Delete the cache module to avoid hazards. For example, when the user reinstalls a HubModule, + # if the cache is not cleared, then what the user gets at this time is actually the HubModule + # before uninstallation, this can cause some strange problems, e.g, fail to load model parameters. + if py_module_name in sys.modules: + sys.modules.pop(py_module_name) + py_module = importlib.import_module(py_module_name) sys.path.pop(0) @@ -277,3 +285,10 @@ def sys_stdout_encoding() -> str: if encoding is None: encoding = get_platform_default_encoding() return encoding + + +def md5(text: str): + ''' + ''' + md5code = hashlib.md5(text.encode()) + return md5code.hexdigest()