提交 48363091 编写于 作者: W wuzewu

Update git source

上级 8e7b7669
...@@ -150,7 +150,9 @@ class LocalModuleManager(object): ...@@ -150,7 +150,9 @@ class LocalModuleManager(object):
archive: str = None, archive: str = None,
url: str = None, url: str = None,
version: str = None, version: str = None,
source: str = None) -> HubModule: source: str = None,
update: bool = False,
branch: str = None) -> HubModule:
''' '''
Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a
module that meets the conditions (both name and version) already installed, the installation step will be module that meets the conditions (both name and version) already installed, the installation step will be
...@@ -167,7 +169,7 @@ class LocalModuleManager(object): ...@@ -167,7 +169,7 @@ class LocalModuleManager(object):
if name: if name:
lock = filelock.FileLock(os.path.join(TMP_HOME, name)) lock = filelock.FileLock(os.path.join(TMP_HOME, name))
with lock: with lock:
hub_module_cls = self.search(name) hub_module_cls = self.search(name, source, branch)
if hub_module_cls and hub_module_cls.version.match(version): if hub_module_cls and hub_module_cls.version.match(version):
directory = self._get_normalized_path(hub_module_cls.name) directory = self._get_normalized_path(hub_module_cls.name)
if version: if version:
...@@ -177,7 +179,9 @@ class LocalModuleManager(object): ...@@ -177,7 +179,9 @@ class LocalModuleManager(object):
msg = 'Module {} already installed in {}'.format(hub_module_cls.name, directory) msg = 'Module {} already installed in {}'.format(hub_module_cls.name, directory)
log.logger.info(msg) log.logger.info(msg)
return hub_module_cls return hub_module_cls
return self._install_from_name(name, version, source) if source:
return self._install_from_source(name, version, source, update, branch)
return self._install_from_name(name, version)
elif directory: elif directory:
return self._install_from_directory(directory) return self._install_from_directory(directory)
elif archive: elif archive:
...@@ -201,19 +205,30 @@ class LocalModuleManager(object): ...@@ -201,19 +205,30 @@ class LocalModuleManager(object):
log.logger.info('Successfully uninstalled {}'.format(name)) log.logger.info('Successfully uninstalled {}'.format(name))
return True return True
def search(self, name: str, source: str = None, branch: str = None) -> HubModule:
    '''
    Return the installed HubModule called `name`, or None if no installed module matches.

    The in-memory cache is consulted first. On a cache miss the module is loaded from its
    normalized installation directory (if that directory exists) and stored in the cache.

    Args:
        name(str) : Name of the HubModule to look up.
        source(str) : If given, the module only matches when its `source` attribute equals this value.
        branch(str) : If given, the module only matches when its `branch` attribute equals this value.

    Returns:
        The matching HubModule, or None when the module is not installed, fails to load, or
        does not match the requested source / branch.
    '''
    module = None

    if name in self._local_modules:
        module = self._local_modules[name]
    else:
        module_dir = self._get_normalized_path(name)
        if os.path.exists(module_dir):
            try:
                module = self._local_modules[name] = HubModule.load(module_dir)
            except Exception:
                # Loading is best-effort: a broken installation is reported and treated as
                # "not installed". Only Exception is caught so that KeyboardInterrupt and
                # SystemExit still propagate.
                log.logger.warning('An error was encountered while loading {}'.format(name))

    if not module:
        return None

    # An empty / None filter means "do not filter on this attribute".
    if source and source != module.source:
        return None

    if branch and branch != module.branch:
        return None

    return module
def list(self) -> List[HubModule]: def list(self) -> List[HubModule]:
'''List all installed HubModule.''' '''List all installed HubModule.'''
...@@ -231,23 +246,17 @@ class LocalModuleManager(object): ...@@ -231,23 +246,17 @@ class LocalModuleManager(object):
return self._install_from_archive(file) return self._install_from_archive(file)
def _install_from_name(self, name: str, version: str = None, source: str = None) -> HubModule: def _install_from_name(self, name: str, version: str = None) -> HubModule:
'''Install HubModule by name search result''' '''Install HubModule by name search result'''
if name in self._local_modules: result = module_server.search_module(name=name, version=version)
if self._local_modules[name].version.match(version):
return self._local_modules[name]
result = module_server.search_module(name=name, version=version, source=source)
for item in result: for item in result:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
if source or 'source' in item:
return self._install_from_source(result)
return self._install_from_url(item['url']) return self._install_from_url(item['url'])
module_infos = module_server.get_module_info(name=name, source=source) module_infos = module_server.get_module_info(name=name)
# The HubModule with the specified name cannot be found # The HubModule with the specified name cannot be found
if not module_infos: if not module_infos:
raise HubModuleNotFoundError(name=name, version=version, source=source) raise HubModuleNotFoundError(name=name, version=version)
valid_infos = {} valid_infos = {}
if version: if version:
...@@ -260,29 +269,43 @@ class LocalModuleManager(object): ...@@ -260,29 +269,43 @@ class LocalModuleManager(object):
# Cannot find a HubModule that meets the version # Cannot find a HubModule that meets the version
if valid_infos: if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version) raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source) raise HubModuleNotFoundError(name=name, info=module_infos, version=version)
def _install_from_source(self, name: str, version: str, source: str, update: bool = False,
                         branch: str = None) -> HubModule:
    '''
    Install a HubModule from Git Repo

    The module server is searched with the given name / version / source / update / branch.
    For the first result whose name equals `name` and whose version matches `version`, any
    locally installed module with the same name is uninstalled, a `module.py` that imports
    the module class from the repo's `hubconf` is generated in the installation directory,
    and the source and branch are recorded in `_source_info.yaml` next to it.

    Args:
        name(str) : Name of the HubModule to install.
        version(str) : Version requirement passed to the search and to the version match.
        source(str) : Source passed to the module server and recorded in `_source_info.yaml`.
        update(bool) : Forwarded to `module_server.search_module`.
        branch(str) : Forwarded to `module_server.search_module` and recorded in `_source_info.yaml`.

    Returns:
        The HubModule loaded from the installation directory.

    Raises:
        HubModuleNotFoundError: If no search result matches `name` and `version`.
    '''
    result = module_server.search_module(name=name, source=source, version=version, update=update, branch=branch)
    # Treat a missing result the same as an empty one so that the not-found error below is raised.
    for item in result or []:
        if item['name'] != name or not item['version'].match(version):
            continue

        # uninstall local module
        if self.search(name):
            self.uninstall(name)

        installed_path = self._get_normalized_path(name)
        os.makedirs(installed_path, exist_ok=True)
        module_file = os.path.join(installed_path, 'module.py')

        # Generate a module.py file to reference objects from Git Repo.
        # The repo path is written with repr() so that backslashes and quotes in the path
        # (e.g. Windows paths) still produce a valid Python string literal.
        with open(module_file, 'w') as file:
            file.write('import sys\n\n')
            file.write('sys.path.insert(0, {!r})\n'.format(item['path']))
            file.write('from hubconf import {}\n'.format(item['class']))
            file.write('sys.path.pop(0)\n')

        # Record where the module came from.
        # NOTE(review): when branch is None this writes the literal text "None" - confirm
        # that is what the reader of this file expects.
        source_info_file = os.path.join(installed_path, '_source_info.yaml')
        with open(source_info_file, 'w') as file:
            file.write('source: {}\n'.format(source))
            file.write('branch: {}'.format(branch))

        self._local_modules[name] = HubModule.load(installed_path)

        if version:
            log.logger.info('Successfully installed {}-{}'.format(name, version))
        else:
            log.logger.info('Successfully installed {}'.format(name))
        return self._local_modules[name]

    raise HubModuleNotFoundError(name=name, version=version, source=source)
def _install_from_directory(self, directory: str) -> HubModule: def _install_from_directory(self, directory: str) -> HubModule:
'''Install a HubModule from directory containing module.py''' '''Install a HubModule from directory containing module.py'''
......
...@@ -22,7 +22,7 @@ from typing import Callable, Generic, List, Optional ...@@ -22,7 +22,7 @@ from typing import Callable, Generic, List, Optional
from easydict import EasyDict from easydict import EasyDict
from paddlehub.utils import log, utils from paddlehub.utils import parser, log, utils
from paddlehub.compat.module.module_v1 import ModuleV1 from paddlehub.compat.module.module_v1 import ModuleV1
...@@ -62,11 +62,17 @@ class Module(object): ...@@ -62,11 +62,17 @@ class Module(object):
''' '''
''' '''
def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs): def __new__(cls,
name: str = None,
directory: str = None,
version: str = None,
source: str = None,
update: bool = False,
**kwargs):
if cls.__name__ == 'Module': if cls.__name__ == 'Module':
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name: if name:
module = cls.init_with_name(name=name, version=version, **kwargs) module = cls.init_with_name(name=name, version=version, source=source, update=update, **kwargs)
elif directory: elif directory:
module = cls.init_with_directory(directory=directory, **kwargs) module = cls.init_with_directory(directory=directory, **kwargs)
else: else:
...@@ -81,7 +87,7 @@ class Module(object): ...@@ -81,7 +87,7 @@ class Module(object):
if directory.endswith(os.sep): if directory.endswith(os.sep):
directory = directory[:-1] directory = directory[:-1]
# If module description file existed, try to load as ModuleV1 # If the module description file existed, try to load as ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb') desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file): if os.path.exists(desc_file):
return ModuleV1.load(directory) return ModuleV1.load(directory)
...@@ -99,6 +105,15 @@ class Module(object): ...@@ -99,6 +105,15 @@ class Module(object):
raise InvalidHubModule(directory) raise InvalidHubModule(directory)
user_module_cls.directory = directory user_module_cls.directory = directory
source_info_file = os.path.join(directory, '_source_info.yaml')
if os.path.exists(source_info_file):
info = parser.yaml_parser.parse(source_info_file)
user_module_cls.source = info.get('source', '')
user_module_cls.branch = info.get('branch', '')
else:
user_module_cls.source = ''
user_module_cls.branch = ''
return user_module_cls return user_module_cls
@classmethod @classmethod
...@@ -128,14 +143,20 @@ class Module(object): ...@@ -128,14 +143,20 @@ class Module(object):
raise InvalidHubModule(directory) raise InvalidHubModule(directory)
@classmethod @classmethod
def init_with_name(cls, name: str, version: str = None, **kwargs): def init_with_name(cls,
name: str,
version: str = None,
source: str = None,
update: bool = False,
branch: str = None,
**kwargs):
''' '''
''' '''
from paddlehub.module.manager import LocalModuleManager from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager() manager = LocalModuleManager()
user_module_cls = manager.search(name) user_module_cls = manager.search(name, source=source, branch=branch)
if not user_module_cls or not user_module_cls.version.match(version): if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name=name, version=version) user_module_cls = manager.install(name=name, version=version, source=source, update=update, branch=branch)
directory = manager._get_normalized_path(user_module_cls.name) directory = manager._get_normalized_path(user_module_cls.name)
......
...@@ -19,14 +19,13 @@ import os ...@@ -19,14 +19,13 @@ import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List
from urllib.parse import urlparse
import git import git
from git import Repo from git import Repo
from paddlehub.module.module import Module as HubModule from paddlehub.module.module import RunModule
from paddlehub.env import SOURCES_HOME from paddlehub.env import SOURCES_HOME
from paddlehub.utils import log from paddlehub.utils import log, utils
class GitSource(object): class GitSource(object):
...@@ -40,9 +39,8 @@ class GitSource(object): ...@@ -40,9 +39,8 @@ class GitSource(object):
def __init__(self, url: str, path: str = None): def __init__(self, url: str, path: str = None):
self.url = url self.url = url
self._parse_result = urlparse(self.url) self.path = os.path.join(SOURCES_HOME, utils.md5(url))
self.path = os.path.join(SOURCES_HOME, self._parse_result.path[1:])
if self.path.endswith('.git'): if self.path.endswith('.git'):
self.path = self.path[:-4] self.path = self.path[:-4]
...@@ -56,8 +54,21 @@ class GitSource(object): ...@@ -56,8 +54,21 @@ class GitSource(object):
self.hub_modules = OrderedDict() self.hub_modules = OrderedDict()
self.load_hub_modules() self.load_hub_modules()
def checkout(self, branch: str):
    '''
    Check out `branch` in the local repository and reload the hub modules.

    Failures are not raised to the caller; a warning is logged instead.

    Args:
        branch(str) : Name passed to `git checkout`.
    '''
    try:
        self.repo.git.checkout(branch)
        # reload modules
        self.load_hub_modules()
    except Exception:
        # Best-effort: only Exception is caught so that KeyboardInterrupt and SystemExit
        # are not swallowed.
        log.logger.warning('An error occurred while checkout {}'.format(self.path))
def update(self):
    '''
    Pull from the repository's remote and reload the hub modules.

    Failures are not raised to the caller; a warning is logged instead.
    '''
    try:
        # NOTE(review): this pulls `self.repo.branches[0]`, the first entry of the repo's
        # branch list, which may differ from the currently checked-out branch - confirm.
        self.repo.remote().pull(self.repo.branches[0])
        # reload modules
        self.load_hub_modules()
    except Exception:
        # Best-effort: only Exception is caught so that KeyboardInterrupt and SystemExit
        # are not swallowed.
        log.logger.warning('An error occurred while update {}'.format(self.path))
def load_hub_modules(self): def load_hub_modules(self):
if 'hubconf' in sys.modules: if 'hubconf' in sys.modules:
...@@ -68,11 +79,12 @@ class GitSource(object): ...@@ -68,11 +79,12 @@ class GitSource(object):
py_module = importlib.import_module('hubconf') py_module = importlib.import_module('hubconf')
for _item, _cls in inspect.getmembers(py_module, inspect.isclass): for _item, _cls in inspect.getmembers(py_module, inspect.isclass):
_item = py_module.__dict__[_item] _item = py_module.__dict__[_item]
if issubclass(_item, HubModule): if issubclass(_item, RunModule):
self.hub_modules[_item.name] = _item self.hub_modules[_item.name] = _item
except: except:
raise self.hub_modules = OrderedDict()
log.logger.warning('An error occurred while loading {}'.format(self.path)) log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path) sys.path.remove(self.path)
def search_module(self, name: str, version: str = None) -> List[dict]: def search_module(self, name: str, version: str = None) -> List[dict]:
......
...@@ -17,6 +17,7 @@ from collections import OrderedDict ...@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import List from typing import List
from paddlehub.server import ServerSource, GitSource from paddlehub.server import ServerSource, GitSource
from paddlehub.utils import utils
PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub' PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
...@@ -27,7 +28,7 @@ class HubServer(object): ...@@ -27,7 +28,7 @@ class HubServer(object):
def __init__(self): def __init__(self):
self.sources = OrderedDict() self.sources = OrderedDict()
def _generate_source(self, url: str, source_type: str = 'server'): def _generate_source(self, url: str, source_type: str = 'git'):
if source_type == 'server': if source_type == 'server':
source = ServerSource(url) source = ServerSource(url)
elif source_type == 'git': elif source_type == 'git':
...@@ -36,16 +37,29 @@ class HubServer(object): ...@@ -36,16 +37,29 @@ class HubServer(object):
raise RuntimeError() raise RuntimeError()
return source return source
def add_source(self, url: str, key: str = None, source_type: str = 'server'): def _get_source_key(self, url: str):
return 'source_{}'.format(utils.md5(url))
def add_source(self, url: str, source_type: str = 'git'):
'''Add a module source(GitSource or ServerSource)''' '''Add a module source(GitSource or ServerSource)'''
key = "source_{}".format(len(self.sources)) if not key else key key = self._get_source_key(url)
self.sources[key] = self._generate_source(url, source_type) self.sources[key] = self._generate_source(url, source_type)
def remove_source(self, url: str = None, key: str = None):
    '''
    Remove a module source

    Args:
        url(str) : Url of the source to remove. Used to derive the key when `key` is not given.
        key(str) : Key of the source to remove. Takes precedence over `url` when both are given.

    Raises:
        KeyError: If no source is registered under the resulting key.
    '''
    # Sources are stored under a key derived from their url, so a url alone is enough to
    # locate the entry.
    if key is None and url is not None:
        key = self._get_source_key(url)
    self.sources.pop(key)
def get_source(self, url: str):
    '''
    Return the source registered for `url`, or None if no source is registered for it.

    Args:
        url(str) : Url whose derived key is looked up in the registered sources.
    '''
    return self.sources.get(self._get_source_key(url))
def search_module(self,
name: str,
version: str = None,
source: str = None,
update: bool = False,
branch: str = None) -> List[dict]:
''' '''
Search PaddleHub module Search PaddleHub module
...@@ -53,9 +67,16 @@ class HubServer(object): ...@@ -53,9 +67,16 @@ class HubServer(object):
name(str) : PaddleHub module name name(str) : PaddleHub module name
version(str) : PaddleHub module version version(str) : PaddleHub module version
''' '''
return self.search_resource(type='module', name=name, version=version, source=source) return self.search_resource(
type='module', name=name, version=version, source=source, update=update, branch=branch)
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]:
def search_resource(self,
type: str,
name: str,
version: str = None,
source: str = None,
update: bool = False,
branch: str = None) -> List[dict]:
''' '''
Search PaddleHub Resource Search PaddleHub Resource
...@@ -66,6 +87,12 @@ class HubServer(object): ...@@ -66,6 +87,12 @@ class HubServer(object):
''' '''
sources = self.sources.values() if not source else [self._generate_source(source)] sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources: for source in sources:
if isinstance(source, GitSource) and update:
source.update()
if isinstance(source, GitSource) and branch:
source.checkout(branch)
result = source.search_resource(name=name, type=type, version=version) result = source.search_resource(name=name, type=type, version=version)
if result: if result:
return result return result
...@@ -83,4 +110,4 @@ class HubServer(object): ...@@ -83,4 +110,4 @@ class HubServer(object):
module_server = HubServer() module_server = HubServer()
module_server.add_source(PADDLEHUB_PUBLIC_SERVER) module_server.add_source(PADDLEHUB_PUBLIC_SERVER, source_type='server')
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import base64 import base64
import contextlib import contextlib
import cv2 import cv2
import hashlib
import importlib import importlib
import math import math
import os import os
...@@ -24,10 +25,10 @@ import sys ...@@ -24,10 +25,10 @@ import sys
import time import time
import tempfile import tempfile
import types import types
import numpy as np
from typing import Generator from typing import Generator
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np
import packaging.version import packaging.version
import paddlehub.env as hubenv import paddlehub.env as hubenv
...@@ -241,6 +242,13 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType: ...@@ -241,6 +242,13 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
py_module_name(str) : Module name to be loaded py_module_name(str) : Module name to be loaded
''' '''
sys.path.insert(0, python_path) sys.path.insert(0, python_path)
# Drop any cached copy of the module before importing it. Otherwise, after a user reinstalls
# a HubModule, the import below would return the HubModule from before the uninstallation,
# which can cause subtle problems, e.g. failing to load model parameters.
if py_module_name in sys.modules:
sys.modules.pop(py_module_name)
py_module = importlib.import_module(py_module_name) py_module = importlib.import_module(py_module_name)
sys.path.pop(0) sys.path.pop(0)
...@@ -277,3 +285,10 @@ def sys_stdout_encoding() -> str: ...@@ -277,3 +285,10 @@ def sys_stdout_encoding() -> str:
if encoding is None: if encoding is None:
encoding = get_platform_default_encoding() encoding = get_platform_default_encoding()
return encoding return encoding
def md5(text: str) -> str:
    '''
    Return the hexadecimal MD5 digest of a string.

    Args:
        text(str) : Text to hash. It is encoded with the default str encoding (UTF-8) first.

    Returns:
        The digest as a 32-character lowercase hexadecimal string.
    '''
    return hashlib.md5(text.encode()).hexdigest()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册