提交 9a1eac7b 编写于 作者: W wuzewu

Add module install tips

上级 613375e5
......@@ -28,20 +28,83 @@ from paddlehub.utils import xarfile, log, utils, pypi
class HubModuleNotFoundError(Exception):
def __init__(self, name, version=None, source=None):
def __init__(self, name: str, info: dict = None, version: str = None, source: str = None):
self.name = name
self.version = version
self.info = info
self.source = source
def __str__(self):
msg = '{}'.format(self.name)
if self.version:
msg += '-{}'.format(self.version)
if self.source:
msg += ' from {}'.format(self.source)
tips = 'No HubModule named {} was found'.format(msg)
tips = 'No HubModule named {} was found'.format(log.FormattedText(text=msg, color='red'))
if self.info:
sort_infos = sorted(self.info.items(), key=lambda x: utils.Version(x[0]))
table = log.Table()
table.append(
*['Name', 'Version', 'PaddlePaddle Version Required', 'PaddleHub Version Required'],
widths=[15, 10, 35, 35],
aligns=['^', '^', '^', '^'],
colors=['cyan', 'cyan', 'cyan', 'cyan'])
for _ver, info in sort_infos:
paddle_version = 'Any' if not info['paddle_version'] else ', '.join(info['paddle_version'])
hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version'])
table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
tips += ', \n{}'.format(table)
return tips
class EnvironmentMismatchError(Exception):
def __init__(self, name: str, info: dict, version: str = None):
self.name = name
self.version = version
self.info = info
def __str__(self):
msg = '{}'.format(self.name)
if self.version:
msg += '-{}'.format(self.version)
tips = '{} cannot be installed because some conditions are not met'.format(
log.FormattedText(text=msg, color='red'))
if self.info:
sort_infos = sorted(self.info.items(), key=lambda x: utils.Version(x[0]))
table = log.Table()
table.append(
*['Name', 'Version', 'PaddlePaddle Version Required', 'PaddleHub Version Required'],
widths=[15, 10, 35, 35],
aligns=['^', '^', '^', '^'],
colors=['cyan', 'cyan', 'cyan', 'cyan'])
import paddle
import paddlehub
for _ver, info in sort_infos:
paddle_version = 'Any' if not info['paddle_version'] else ', '.join(info['paddle_version'])
for version in info['paddle_version']:
if not utils.Version(paddle.__version__).match(version):
paddle_version = '{}(Mismatch)'.format(paddle_version)
break
hub_version = 'Any' if not info['hub_version'] else ', '.join(info['hub_version'])
for version in info['hub_version']:
if not utils.Version(paddlehub.__version__).match(version):
hub_version = '{}(Mismatch)'.format(hub_version)
break
table.append(self.name, _ver, paddle_version, hub_version, aligns=['^', '^', '^', '^'])
tips += ', \n{}'.format(table)
return tips
......@@ -176,7 +239,23 @@ class LocalModuleManager(object):
result = module_server.search_module(name=name, version=version, source=source)
if not result:
raise HubModuleNotFoundError(name, version, source)
module_infos = module_server.get_module_info(name=name, source=source)
# The HubModule with the specified name cannot be found
if not module_infos:
raise HubModuleNotFoundError(name=name, version=version, source=source)
valid_infos = {}
if version:
for _ver, _info in module_infos.items():
if utils.Version(_ver).match(version):
valid_infos[_ver] = _info
else:
valid_infos = list(module_infos.keys())
# Cannot find a HubModule that meets the version
if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
if source or 'source' in result:
return self._install_from_source(result)
......
......@@ -82,9 +82,9 @@ class GitSource(object):
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version)
return self.search_resource(type='module', name=name, version=version)
def search_resouce(self, type: str, name: str, version: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None) -> dict:
'''
Search PaddleHub Resource
......
......@@ -22,6 +22,7 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
class HubServer(object):
'''PaddleHub server'''
def __init__(self):
self.sources = OrderedDict()
......@@ -51,9 +52,9 @@ class HubServer(object):
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version, source=source)
return self.search_resource(type='module', name=name, version=version, source=source)
def search_resouce(self, type: str, name: str, version: str = None, source: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict:
'''
Search PaddleHub Resource
......@@ -64,10 +65,20 @@ class HubServer(object):
'''
sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources:
result = source.search_resouce(name=name, type=type, version=version)
result = source.search_resource(name=name, type=type, version=version)
if result:
return result
return {}
def get_module_info(self, name: str, source: str = None) -> dict:
'''
'''
sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources:
result = source.get_module_info(name=name)
if result:
return result
return None
return {}
module_server = HubServer()
......
......@@ -14,12 +14,11 @@
# limitations under the License.
import json
import platform
import requests
import sys
from typing import List
import paddlehub
from paddlehub.utils import utils
from paddlehub.utils import utils, platform
class ServerConnectionError(Exception):
......@@ -52,9 +51,9 @@ class ServerSource(object):
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version)
return self.search_resource(type='module', name=name, version=version)
def search_resouce(self, type: str, name: str, version: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None) -> dict:
'''
Search PaddleHub Resource
......@@ -63,36 +62,64 @@ class ServerSource(object):
name(str) : Resource name
version(str) : Resource version
'''
payload = {'environments': {}}
params = {'environments': platform.get_platform_info()}
payload['word'] = name
payload['type'] = type
params['word'] = name
params['type'] = type
if version:
payload['version'] = version
params['version'] = version
# Delay module loading to improve command line speed
import paddle
payload['environments']['hub_version'] = paddlehub.__version__
payload['environments']['paddle_version'] = paddle.__version__
payload['environments']['python_version'] = '.'.join(map(str, sys.version_info[0:3]))
payload['environments']['platform_version'] = platform.version()
payload['environments']['platform_system'] = platform.system()
payload['environments']['platform_architecture'] = platform.architecture()
payload['environments']['platform_type'] = platform.platform()
params['hub_version'] = paddlehub.__version__
params['paddle_version'] = paddle.__version__
api = '{}/search'.format(self._url)
result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
return None
def get_module_info(self, name: str) -> dict:
'''
'''
def _convert_version(version: str) -> List:
result = []
# from [1.5.4, 2.0.0] -> 1.5.4,2.0.0
version = version.replace(' ', '')[1:-1]
version = version.split(',')
if version[0] != '-1.0.0':
result.append('>={}'.format(version[0]))
if len(version) > 1:
if version[1] != '99.0.0':
result.append('<={}'.format(version[1]))
return result
params = {'name': name}
result = self.request(path='info', params=params)
if result['status'] == 0 and len(result['data']) > 0:
infos = {}
for _info in result['data']['info']:
infos[_info['version']] = {
'url': _info['url'],
'paddle_version': _convert_version(_info['paddle_version']),
'hub_version': _convert_version(_info['hub_version'])
}
return infos
return {}
def request(self, path: str, params: dict) -> dict:
'''
'''
api = '{}/{}'.format(self._url, path)
try:
result = requests.get(api, payload, timeout=self._timeout)
result = result.json()
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
else:
print(result)
return None
result = requests.get(api, params, timeout=self._timeout)
return result.json()
except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册