未验证 提交 14755b3b 编写于 作者: B Bin Long 提交者: GitHub

Merge pull request #272 from ShenYuhan/add_search_tip

add version match tips
......@@ -21,7 +21,7 @@ from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.module.manager import default_module_manager
from paddlehub.commands.base_command import BaseCommand
from paddlehub.commands.cml_utils import TablePrinter
from paddlehub.common.cml_utils import TablePrinter
class ListCommand(BaseCommand):
......
......@@ -22,7 +22,7 @@ import argparse
from paddlehub.common import utils
from paddlehub.common.hub_server import default_hub_server
from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.commands.cml_utils import TablePrinter
from paddlehub.common.cml_utils import TablePrinter
class SearchCommand(BaseCommand):
......
......@@ -22,7 +22,7 @@ import argparse
from paddlehub.common import utils
from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.commands.cml_utils import TablePrinter
from paddlehub.common.cml_utils import TablePrinter
from paddlehub.module.manager import default_module_manager
from paddlehub.module.module import Module
from paddlehub.io.parser import yaml_parser
......
......@@ -148,6 +148,21 @@ class HubServer(object):
self.search_resource(
resource_key=module_key, resource_type="Model", update=update)
def search_module_info(self, module_key):
try:
payload = {'name': module_key}
api_url = srv_utils.uri_path(self.get_server_url(), 'info')
r = srv_utils.hub_request(api_url, payload)
if r['status'] == 0 and len(r['data']) > 0:
return [(item['raw_name'], item['version'],
item['paddle_version'], item["hub_version"])
for item in r['data']["info"]]
except:
if self.config.get('debug', False):
raise
else:
pass
def get_resource_url(self,
resource_name,
resource_type=None,
......
......@@ -257,3 +257,38 @@ def sys_stdout_encoding():
if encoding is None:
encoding = get_platform_default_encoding()
return encoding
def version_sum(version):
"""
get sum(version), eg: version_sum(1.4.5) = 1*100*100*100 + 4*100*100 + 5*100
:param version: string("1.3.6")
:return:
"""
sum = 0
version_list = version.split(".")
for i in version_list:
sum = (sum + int(i)) * 100
return sum
def sort_version_key(version_a, version_b):
if version_sum(version_a[1]) > version_sum(version_b[1]):
return -1
elif version_sum(version_a[1]) == version_sum(version_b[1]):
return 0
else:
return 1
def strflist_version(version_list):
version_list = version_list[1:-1].split(",")
result = ""
if version_list[0] != "-1.0.0":
result = ">" + version_list[0]
if version_list[1] != "99.0.0":
if result != "":
result = result + ", " + "<" + version_list[1]
else:
result = "<" + version_list[1]
return result if result != "" else "-"
......@@ -19,6 +19,8 @@ from __future__ import print_function
import os
import shutil
from functools import cmp_to_key
import tarfile
from paddlehub.common import utils
......@@ -26,6 +28,7 @@ from paddlehub.common import srv_utils
from paddlehub.common.downloader import default_downloader
from paddlehub.common.hub_server import default_hub_server
from paddlehub.common.dir import MODULE_HOME
from paddlehub.common.cml_utils import TablePrinter
from paddlehub.module import module_desc_pb2
import paddlehub as hub
from paddlehub.common.logger import logger
......@@ -155,11 +158,79 @@ class LocalModuleManager(object):
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None and installed_module_version
!= module_version) or (name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
return False, tips, None
module_versions_info = default_hub_server.search_module_info(
module_name)
if module_versions_info is not None and len(
module_versions_info) > 0:
if utils.is_windows():
placeholders = [20, 8, 14, 14]
else:
placeholders = [30, 8, 16, 16]
tp = TablePrinter(
titles=[
"ResourceName", "Version", "PaddlePaddle", "PaddleHub"
],
placeholders=placeholders)
module_versions_info.sort(
key=cmp_to_key(utils.sort_version_key))
for resource_name, resource_version, paddle_version, \
hub_version in module_versions_info:
colors = ["yellow", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version,
utils.strflist_version(paddle_version),
utils.strflist_version(hub_version)
],
colors=colors)
tips = "The version of PaddlePaddle or PaddleHub " \
"can not match module, please upgrade your " \
"PaddlePaddle or PaddleHub according to the form " \
"below." + tp.get_text()
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_dir:
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path):
shutil.move(save_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册