提交 8917fb23 编写于 作者: W wuzewu

Add search command

上级 80476dde
#coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import List
from paddlehub.commands import register
from paddlehub.module.manager import LocalModuleManager
from paddlehub.server.server import module_server
from paddlehub.utils import log, platform
@register(name='hub.search', description='Search PaddleHub pretrained model through model keywords.')
class SearchCommand:
def execute(self, argv: List) -> bool:
argv = '.*' if not argv else argv[0]
widths = [20, 8, 30] if platform.is_windows() else [30, 8, 40]
table = log.Table(widths=widths)
table.append(*['ModuleName', 'Version', 'Summary'], aligns=['^', '^', '^'], colors=["blue", "blue", "blue"])
results = module_server.search_module(name=argv)
for result in results:
table.append(result['name'], result['version'], result['summary'])
print(table)
return True
......@@ -238,28 +238,29 @@ class LocalModuleManager(object):
return self._local_modules[name]
result = module_server.search_module(name=name, version=version, source=source)
if not result:
module_infos = module_server.get_module_info(name=name, source=source)
# The HubModule with the specified name cannot be found
if not module_infos:
raise HubModuleNotFoundError(name=name, version=version, source=source)
valid_infos = {}
if version:
for _ver, _info in module_infos.items():
if utils.Version(_ver).match(version):
valid_infos[_ver] = _info
else:
valid_infos = module_infos.copy()
# Cannot find a HubModule that meets the version
if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
if source or 'source' in result:
return self._install_from_source(result)
return self._install_from_url(result['url'])
for item in result:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
if source or 'source' in item:
return self._install_from_source(result)
return self._install_from_url(item['url'])
module_infos = module_server.get_module_info(name=name, source=source)
# The HubModule with the specified name cannot be found
if not module_infos:
raise HubModuleNotFoundError(name=name, version=version, source=source)
valid_infos = {}
if version:
for _ver, _info in module_infos.items():
if utils.Version(_ver).match(version):
valid_infos[_ver] = _info
else:
valid_infos = module_infos.copy()
# Cannot find a HubModule that meets the version
if valid_infos:
raise EnvironmentMismatchError(name=name, info=valid_infos, version=version)
raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source)
def _install_from_source(self, source: str) -> HubModule:
'''Install a HubModule from Git Repo'''
......
......@@ -18,6 +18,7 @@ import importlib
import os
import sys
from collections import OrderedDict
from typing import List
from urllib.parse import urlparse
import git
......@@ -74,7 +75,7 @@ class GitSource(object):
log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path)
def search_module(self, name: str, version: str = None) -> dict:
def search_module(self, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -84,7 +85,7 @@ class GitSource(object):
'''
return self.search_resource(type='module', name=name, version=version)
def search_resource(self, type: str, name: str, version: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -95,13 +96,13 @@ class GitSource(object):
'''
module = self.hub_modules.get(name, None)
if module and module.version.match(version):
return {
return [{
'version': module.version,
'name': module.name,
'path': self.path,
'class': module.__name__,
'source': self.url
}
}]
return None
@classmethod
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from collections import OrderedDict
from typing import List
from paddlehub.server import ServerSource, GitSource
......@@ -44,7 +45,7 @@ class HubServer(object):
'''Remove a module source'''
self.sources.pop(key)
def search_module(self, name: str, version: str = None, source: str = None) -> dict:
def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -54,7 +55,7 @@ class HubServer(object):
'''
return self.search_resource(type='module', name=name, version=version, source=source)
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -68,7 +69,7 @@ class HubServer(object):
result = source.search_resource(name=name, type=type, version=version)
if result:
return result
return {}
return []
def get_module_info(self, name: str, source: str = None) -> dict:
'''
......
......@@ -43,7 +43,7 @@ class ServerSource(object):
self._url = url
self._timeout = timeout
def search_module(self, name: str, version: str = None) -> dict:
def search_module(self, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub module
......@@ -53,7 +53,7 @@ class ServerSource(object):
'''
return self.search_resource(type='module', name=name, version=version)
def search_resource(self, type: str, name: str, version: str = None) -> dict:
def search_resource(self, type: str, name: str, version: str = None) -> List[dict]:
'''
Search PaddleHub Resource
......@@ -76,9 +76,7 @@ class ServerSource(object):
result = self.request(path='search', params=params)
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
return result['data']
return None
def get_module_info(self, name: str) -> dict:
......
......@@ -41,11 +41,17 @@ def download(name: str, save_path: str, version: str = None):
if os.path.exists(file):
return
resource = module_server.search_resouce(name=name, version=version, type='Model')
if not resource:
resources = module_server.search_resouce(name=name, version=version, type='Model')
if not resources:
raise ResourceNotFoundError(name, version)
for item in resources:
if item['name'] == name and utils.Version(item['version']).match(version):
url = item['url']
break
else:
raise ResourceNotFoundError(name, version)
url = resource['url']
with utils.generate_tempdir() as _dir:
if not os.path.exists(save_path):
os.makedirs(save_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册