From 8917fb23e8a6afe554db613112c710fc0cd0e65a Mon Sep 17 00:00:00 2001 From: wuzewu Date: Thu, 24 Sep 2020 19:37:51 +0800 Subject: [PATCH] Add search command --- paddlehub/commands/search.py | 40 +++++++++++++++++++++++++++ paddlehub/module/manager.py | 45 ++++++++++++++++--------------- paddlehub/server/git_source.py | 9 ++++--- paddlehub/server/server.py | 7 ++--- paddlehub/server/server_source.py | 8 +++--- paddlehub/utils/paddlex.py | 12 ++++++--- 6 files changed, 84 insertions(+), 37 deletions(-) diff --git a/paddlehub/commands/search.py b/paddlehub/commands/search.py index e69de29b..d766c9ba 100644 --- a/paddlehub/commands/search.py +++ b/paddlehub/commands/search.py @@ -0,0 +1,40 @@ +#coding:utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +from typing import List + +from paddlehub.commands import register +from paddlehub.module.manager import LocalModuleManager +from paddlehub.server.server import module_server +from paddlehub.utils import log, platform + + +@register(name='hub.search', description='Search PaddleHub pretrained model through model keywords.') +class SearchCommand: + def execute(self, argv: List) -> bool: + argv = '.*' if not argv else argv[0] + + widths = [20, 8, 30] if platform.is_windows() else [30, 8, 40] + table = log.Table(widths=widths) + table.append(*['ModuleName', 'Version', 'Summary'], aligns=['^', '^', '^'], colors=["blue", "blue", "blue"]) + + results = module_server.search_module(name=argv) + for result in results: + table.append(result['name'], result['version'], result['summary']) + + print(table) + return True diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py index b1130440..d2027596 100644 --- a/paddlehub/module/manager.py +++ b/paddlehub/module/manager.py @@ -238,28 +238,29 @@ class LocalModuleManager(object): return self._local_modules[name] result = module_server.search_module(name=name, version=version, source=source) - if not result: - module_infos = module_server.get_module_info(name=name, source=source) - # The HubModule with the specified name cannot be found - if not module_infos: - raise HubModuleNotFoundError(name=name, version=version, source=source) - - valid_infos = {} - if version: - for _ver, _info in module_infos.items(): - if utils.Version(_ver).match(version): - valid_infos[_ver] = _info - else: - valid_infos = module_infos.copy() - - # Cannot find a HubModule that meets the version - if valid_infos: - raise EnvironmentMismatchError(name=name, info=valid_infos, version=version) - raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source) - - if source or 'source' in result: - return self._install_from_source(result) - return self._install_from_url(result['url']) + for item in result: + if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): + if source or 'source' in item: + return self._install_from_source(result) + return self._install_from_url(item['url']) + + module_infos = module_server.get_module_info(name=name, source=source) + # The HubModule with the specified name cannot be found + if not module_infos: + raise HubModuleNotFoundError(name=name, version=version, source=source) + + valid_infos = {} + if version: + for _ver, _info in module_infos.items(): + if utils.Version(_ver).match(version): + valid_infos[_ver] = _info + else: + valid_infos = module_infos.copy() + + # Cannot find a HubModule that meets the version + if valid_infos: + raise EnvironmentMismatchError(name=name, info=valid_infos, version=version) + raise HubModuleNotFoundError(name=name, info=module_infos, version=version, source=source) def _install_from_source(self, source: str) -> HubModule: '''Install a HubModule from Git Repo''' diff --git a/paddlehub/server/git_source.py b/paddlehub/server/git_source.py index 398f3b13..906373c9 100644 --- a/paddlehub/server/git_source.py +++ b/paddlehub/server/git_source.py @@ -18,6 +18,7 @@ import importlib import os import sys from collections import OrderedDict +from typing import List from urllib.parse import urlparse import git @@ -74,7 +75,7 @@ class GitSource(object): log.logger.warning('An error occurred while loading {}'.format(self.path)) sys.path.remove(self.path) - def search_module(self, name: str, version: str = None) -> dict: + def search_module(self, name: str, version: str = None) -> List[dict]: ''' Search PaddleHub module @@ -84,7 +85,7 @@ class GitSource(object): ''' return self.search_resource(type='module', name=name, version=version) - def search_resource(self, type: str, name: str, version: str = None) -> dict: + def search_resource(self, type: str, name: str, version: str = None) -> List[dict]: ''' Search PaddleHub Resource @@ -95,13 +96,13 @@ class GitSource(object): ''' module = self.hub_modules.get(name, None) if module and module.version.match(version): - return { + return [{ 'version': module.version, 'name': module.name, 'path': self.path, 'class': module.__name__, 'source': self.url - } + }] return None @classmethod diff --git a/paddlehub/server/server.py b/paddlehub/server/server.py index fb225bb4..60b0ab0b 100644 --- a/paddlehub/server/server.py +++ b/paddlehub/server/server.py @@ -14,6 +14,7 @@ # limitations under the License. from collections import OrderedDict +from typing import List from paddlehub.server import ServerSource, GitSource @@ -44,7 +45,7 @@ class HubServer(object): '''Remove a module source''' self.sources.pop(key) - def search_module(self, name: str, version: str = None, source: str = None) -> dict: + def search_module(self, name: str, version: str = None, source: str = None) -> List[dict]: ''' Search PaddleHub module @@ -54,7 +55,7 @@ class HubServer(object): ''' return self.search_resource(type='module', name=name, version=version, source=source) - def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> dict: + def search_resource(self, type: str, name: str, version: str = None, source: str = None) -> List[dict]: ''' Search PaddleHub Resource @@ -68,7 +69,7 @@ class HubServer(object): result = source.search_resource(name=name, type=type, version=version) if result: return result - return {} + return [] def get_module_info(self, name: str, source: str = None) -> dict: ''' diff --git a/paddlehub/server/server_source.py b/paddlehub/server/server_source.py index e8842575..e539194d 100644 --- a/paddlehub/server/server_source.py +++ b/paddlehub/server/server_source.py @@ -43,7 +43,7 @@ class ServerSource(object): self._url = url self._timeout = timeout - def search_module(self, name: str, version: str = None) -> dict: + def search_module(self, name: str, version: str = None) -> List[dict]: ''' Search PaddleHub module @@ -53,7 +53,7 @@ class ServerSource(object): ''' return self.search_resource(type='module', name=name, version=version) - def search_resource(self, type: str, name: str, version: str = None) -> dict: + def search_resource(self, type: str, name: str, version: str = None) -> List[dict]: ''' Search PaddleHub Resource @@ -76,9 +76,7 @@ class ServerSource(object): result = self.request(path='search', params=params) if result['status'] == 0 and len(result['data']) > 0: - for item in result['data']: - if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): - return item + return result['data'] return None def get_module_info(self, name: str) -> dict: diff --git a/paddlehub/utils/paddlex.py b/paddlehub/utils/paddlex.py index bdc3467a..51df24f9 100644 --- a/paddlehub/utils/paddlex.py +++ b/paddlehub/utils/paddlex.py @@ -41,11 +41,17 @@ def download(name: str, save_path: str, version: str = None): if os.path.exists(file): return - resource = module_server.search_resouce(name=name, version=version, type='Model') - if not resource: + resources = module_server.search_resouce(name=name, version=version, type='Model') + if not resources: + raise ResourceNotFoundError(name, version) + + for item in resources: + if item['name'] == name and utils.Version(item['version']).match(version): + url = item['url'] + break + else: raise ResourceNotFoundError(name, version) - url = resource['url'] with utils.generate_tempdir() as _dir: if not os.path.exists(save_path): os.makedirs(save_path) -- GitLab