提交 48363091 编写于 作者: W wuzewu

Update git source

上级 8e7b7669
......@@ -150,7 +150,9 @@ class LocalModuleManager(object):
archive: str = None,
url: str = None,
version: str = None,
source: str = None) -> HubModule:
source: str = None,
update: bool = False,
branch: str = None) -> HubModule:
'''
Install a HubModule from name or directory or archive file or url. When installing with the name parameter, if a
module that meets the conditions (both name and version) already installed, the installation step will be
......@@ -167,7 +169,7 @@ class LocalModuleManager(object):
if name:
lock = filelock.FileLock(os.path.join(TMP_HOME, name))
with lock:
hub_module_cls = self.search(name)
hub_module_cls = self.search(name, source, branch)
if hub_module_cls and hub_module_cls.version.match(version):
directory = self._get_normalized_path(hub_module_cls.name)
if version:
......@@ -177,7 +179,9 @@ class LocalModuleManager(object):
msg = 'Module {} already installed in {}'.format(hub_module_cls.name, directory)
log.logger.info(msg)
return hub_module_cls
return self._install_from_name(name, version, source)
if source:
return self._install_from_source(name, version, source, update, branch)
return self._install_from_name(name, version)
elif directory:
return self._install_from_directory(directory)
elif archive:
......@@ -201,19 +205,30 @@ class LocalModuleManager(object):
log.logger.info('Successfully uninstalled {}'.format(name))
return True
def search(self, name: str) -> HubModule:
def search(self, name: str, source: str = None, branch: str = None) -> HubModule:
'''Return HubModule If a HubModule with a specific name is found, otherwise None.'''
module = None
if name in self._local_modules:
return self._local_modules[name]
module = self._local_modules[name]
else:
module_dir = self._get_normalized_path(name)
if os.path.exists(module_dir):
try:
module = self._local_modules[name] = HubModule.load(module_dir)
except:
log.logger.warning('An error was encountered while loading {}'.format(name))
module_dir = self._get_normalized_path(name)
if os.path.exists(module_dir):
try:
self._local_modules[name] = HubModule.load(module_dir)
return self._local_modules[name]
except:
log.logger.warning('An error was encountered while loading {}'.format(name))
return None
if not module:
return None
if source and source != module.source:
return None
if branch and branch != module.branch:
return None
return module
def list(self) -> List[HubModule]:
'''List all installed HubModule.'''
......@@ -231,23 +246,17 @@ class LocalModuleManager(object):
return self._install_from_archive(file)
def _install_from_name(self, name: str, version: str = None, source: str = None) -> HubModule:
def _install_from_name(self, name: str, version: str = None) -> HubModule:
'''Install HubModule by name search result'''
if name in self._local_modules:
if self._local_modules[name].version.match(version):
return self._local_modules[name]
result = module_server.search_module(name=name, version=version, source=source)
result = module_server.search_module(name=name, version=version)
for item in result:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
if source or 'source' in item:
return self._install_from_source(result)
return self._install_from_url(item['url'])
module_infos = module_server.get_module_info(name=name, source=source)
module_infos = module_server.get_module_info(name=name)
# The HubModule with the specified name cannot be found
if not module_infos:
raise HubModuleNotFoundError(name=name, version=version, source=source)
raise HubModuleNotFoundError(name=name, version=version)
valid_infos = {}
if version:
......@@ -260,29 +269,43 @@ class LocalModuleManager(object):
# Cannot find a HubModule that meets the version
if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version)
def _install_from_source(self, source: str) -> HubModule:
def _install_from_source(self, name: str, version: str, source: str, update: bool = False,
branch: str = None) -> HubModule:
'''Install a HubModule from Git Repo'''
name = source['name']
cls_name = source['class']
path = source['path']
# uninstall local module
if self.search(name):
self.uninstall(name)
os.makedirs(self._get_normalized_path(name))
module_file = os.path.join(self._get_normalized_path(name), 'module.py')
# Generate a module.py file to reference objects from Git Repo
with open(module_file, 'w') as file:
file.write('import sys\n\n')
file.write('sys.path.insert(0, \'{}\')\n'.format(path))
file.write('from hubconf import {}\n'.format(cls_name))
file.write('sys.path.pop(0)\n')
self._local_modules[name] = HubModule.load(self._get_normalized_path(name))
return self._local_modules[name]
result = module_server.search_module(name=name, source=source, version=version, update=update, branch=branch)
for item in result:
if item['name'] == name and item['version'].match(version):
# uninstall local module
if self.search(name):
self.uninstall(name)
installed_path = self._get_normalized_path(name)
if not os.path.exists(installed_path):
os.makedirs(installed_path)
module_file = os.path.join(installed_path, 'module.py')
# Generate a module.py file to reference objects from Git Repo
with open(module_file, 'w') as file:
file.write('import sys\n\n')
file.write('sys.path.insert(0, \'{}\')\n'.format(item['path']))
file.write('from hubconf import {}\n'.format(item['class']))
file.write('sys.path.pop(0)\n')
source_info_file = os.path.join(installed_path, '_source_info.yaml')
with open(source_info_file, 'w') as file:
file.write('source: {}\n'.format(source))
file.write('branch: {}'.format(branch))
self._local_modules[name] = HubModule.load(installed_path)
if version:
log.logger.info('Successfully installed {}-{}'.format(name, version))
else:
log.logger.info('Successfully installed {}'.format(name))
return self._local_modules[name]
raise HubModuleNotFoundError(name=name, version=version, source=source)
def _install_from_directory(self, directory: str) -> HubModule:
'''Install a HubModule from directory containing module.py'''
......
......@@ -22,7 +22,7 @@ from typing import Callable, Generic, List, Optional
from easydict import EasyDict
from paddlehub.utils import log, utils
from paddlehub.utils import parser, log, utils
from paddlehub.compat.module.module_v1 import ModuleV1
......@@ -62,11 +62,17 @@ class Module(object):
'''
'''
def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
def __new__(cls,
name: str = None,
directory: str = None,
version: str = None,
source: str = None,
update: bool = False,
**kwargs):
if cls.__name__ == 'Module':
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name:
module = cls.init_with_name(name=name, version=version, **kwargs)
module = cls.init_with_name(name=name, version=version, source=source, update=update, **kwargs)
elif directory:
module = cls.init_with_directory(directory=directory, **kwargs)
else:
......@@ -81,7 +87,7 @@ class Module(object):
if directory.endswith(os.sep):
directory = directory[:-1]
# If module description file existed, try to load as ModuleV1
# If the module description file existed, try to load as ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file):
return ModuleV1.load(directory)
......@@ -99,6 +105,15 @@ class Module(object):
raise InvalidHubModule(directory)
user_module_cls.directory = directory
source_info_file = os.path.join(directory, '_source_info.yaml')
if os.path.exists(source_info_file):
info = parser.yaml_parser.parse(source_info_file)
user_module_cls.source = info.get('source', '')
user_module_cls.branch = info.get('branch', '')
else:
user_module_cls.source = ''
user_module_cls.branch = ''
return user_module_cls
@classmethod
......@@ -128,14 +143,20 @@ class Module(object):
raise InvalidHubModule(directory)
@classmethod
def init_with_name(cls, name: str, version: str = None, **kwargs):
def init_with_name(cls,
name: str,
version: str = None,
source: str = None,
update: bool = False,
branch: str = None,
**kwargs):
'''
'''
from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager()
user_module_cls = manager.search(name)
user_module_cls = manager.search(name, source=source, branch=branch)
if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name=name, version=version)
user_module_cls = manager.install(name=name, version=version, source=source, update=update, branch=branch)
directory = manager._get_normalized_path(user_module_cls.name)
......
......@@ -19,14 +19,13 @@ import os
import sys
from collections import OrderedDict
from typing import List
from urllib.parse import urlparse
import git
from git import Repo
from paddlehub.module.module import Module as HubModule
from paddlehub.module.module import RunModule
from paddlehub.env import SOURCES_HOME
from paddlehub.utils import log
from paddlehub.utils import log, utils
class GitSource(object):
......@@ -40,9 +39,8 @@ class GitSource(object):
def __init__(self, url: str, path: str = None):
self.url = url
self._parse_result = urlparse(self.url)
self.path = os.path.join(SOURCES_HOME, utils.md5(url))
self.path = os.path.join(SOURCES_HOME, self._parse_result.path[1:])
if self.path.endswith('.git'):
self.path = self.path[:-4]
......@@ -56,8 +54,21 @@ class GitSource(object):
self.hub_modules = OrderedDict()
self.load_hub_modules()
def checkout(self, branch: str):
try:
self.repo.git.checkout(branch)
# reload modules
self.load_hub_modules()
except:
log.logger.warning('An error occurred while checkout {}'.format(self.path))
def update(self):
self.repo.remote().pull()
try:
self.repo.remote().pull(self.repo.branches[0])
# reload modules
self.load_hub_modules()
except:
log.logger.warning('An error occurred while update {}'.format(self.path))
def load_hub_modules(self):
if 'hubconf' in sys.modules:
......@@ -68,11 +79,12 @@ class GitSource(object):
py_module = importlib.import_module('hubconf')
for _item, _cls in inspect.getmembers(py_module, inspect.isclass):
_item = py_module.__dict__[_item]
if issubclass(_item, HubModule):
if issubclass(_item, RunModule):
self.hub_modules[_item.name] = _item
except:
raise
self.hub_modules = OrderedDict()
log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path)
def search_module(self, name: str, version: str = None) -> List[dict]:
......
......@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import List
from paddlehub.server import ServerSource, GitSource
from paddlehub.utils import utils
PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
......@@ -27,7 +28,7 @@ class HubServer(object):
def __init__(self):
self.sources = OrderedDict()
def _generate_source(self, url: str, source_type: str = 'server'):
def _generate_source(self, url: str, source_type: str = 'git'):
if source_type == 'server':
source = ServerSource(url)
elif source_type == 'git':
......@@ -36,16 +37,29 @@ class HubServer(object):
raise RuntimeError()
return source
def add_source(self, url: str, key: str = None, source_type: str = 'server'):
def _get_source_key(self, url: str):
return 'source_{}'.format(utils.md5(url))
def add_source(self, url: str, source_type: str = 'git'):
'''Add a module source(GitSource or ServerSource)'''
key = "source_{}".format(len(self.sources)) if not key else key
key = self._get_source_key(url)
self.sources[key] = self._generate_source(url, source_type)
def remove_source(self, url: str = None, key: str = None):
'''Remove a module source'''
self.sources.pop(key)
def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]:
def get_source(self, url: str):
''''''
key = self._get_source_key(url)
return self.sources.get(key, None)
def search_module(self,
name: str,
version: str = None,
source: str = None,
update: bool = False,
branch: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -53,9 +67,16 @@ class HubServer(object):
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resource(type='module', name=name, version=version, source=source)
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]:
return self.search_resource(
type='module', name=name, version=version, source=source, update=update, branch=branch)
def search_resource(self,
type: str,
name: str,
version: str = None,
source: str = None,
update: bool = False,
branch: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -66,6 +87,12 @@ class HubServer(object):
'''
sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources:
if isinstance(source, GitSource) and update:
source.update()
if isinstance(source, GitSource) and branch:
source.checkout(branch)
result = source.search_resource(name=name, type=type, version=version)
if result:
return result
......@@ -83,4 +110,4 @@ class HubServer(object):
module_server = HubServer()
module_server.add_source(PADDLEHUB_PUBLIC_SERVER)
module_server.add_source(PADDLEHUB_PUBLIC_SERVER, source_type='server')
......@@ -16,6 +16,7 @@
import base64
import contextlib
import cv2
import hashlib
import importlib
import math
import os
......@@ -24,10 +25,10 @@ import sys
import time
import tempfile
import types
import numpy as np
from typing import Generator
from urllib.parse import urlparse
import numpy as np
import packaging.version
import paddlehub.env as hubenv
......@@ -241,6 +242,13 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
py_module_name(str) : Module name to be loaded
'''
sys.path.insert(0, python_path)
# Delete the cache module to avoid hazards. For example, when the user reinstalls a HubModule,
# if the cache is not cleared, then what the user gets at this time is actually the HubModule
# before uninstallation, this can cause some strange problems, e.g, fail to load model parameters.
if py_module_name in sys.modules:
sys.modules.pop(py_module_name)
py_module = importlib.import_module(py_module_name)
sys.path.pop(0)
......@@ -277,3 +285,10 @@ def sys_stdout_encoding() -> str:
if encoding is None:
encoding = get_platform_default_encoding()
return encoding
def md5(text: str):
'''
'''
md5code = hashlib.md5(text.encode())
return md5code.hexdigest()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册