提交 270e44ab 编写于 作者: W wuzewu

search command add colorful output

上级 4b4b48ef
......@@ -22,6 +22,7 @@ from paddlehub.common.logger import logger
from paddlehub.common import utils
from paddlehub.common.hub_server import default_hub_server
from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.commands.cml_utils import TablePrinter
class SearchCommand(BaseCommand):
......@@ -43,15 +44,23 @@ class SearchCommand(BaseCommand):
self.help()
return False
module_name = argv[0]
module_list = default_hub_server.search_module(module_name)
text = "\n"
text += color_bold_text(
"red", " %-20s\t\t%s\n" % ("ModuleName", "ModuleVersion"))
text += " %-20s\t\t%s\n" % ("--", "--")
for module_name, module_version in module_list:
text += " %-20s\t\t%s\n" % (module_name, module_version)
print(text)
resource_name = argv[0]
resource_list = default_hub_server.search_resource(resource_name)
tp = TablePrinter(
titles=["ResourceName", "Type", "Version", "Summary"],
placeholders=[25, 10, 10, 35])
for resource_name, resource_type, resource_version, resource_summary in resource_list:
if resource_type == "Module":
colors = ["yellow", None, None, None]
else:
colors = ["light_red", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version, resource_type,
resource_summary
],
colors=colors)
print(tp.get_text())
return True
......
......@@ -21,11 +21,10 @@ import time
from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader
from paddlehub.io.reader import csv_reader
from paddlehub.io.reader import yaml_reader
import paddlehub as hub
MODULE_LIST_FILE = "module_list_file.csv"
MODEL_LIST_FILE = "model_list_file.csv"
RESOURCE_LIST_FILE = "resource_list_file.yml"
CACHE_TIME = 60 * 10
......@@ -35,149 +34,115 @@ class HubServer:
server_url = "https://paddlehub.bj.bcebos.com/"
utils.check_url(server_url)
self.server_url = server_url
self._load_module_list_file_if_valid()
self._load_model_list_file_if_valid()
self._load_resource_list_file_if_valid()
def module_list_file_path(self):
return os.path.join(hub.CACHE_HOME, MODULE_LIST_FILE)
def resource_list_file_path(self):
return os.path.join(hub.CACHE_HOME, RESOURCE_LIST_FILE)
def model_list_file_path(self):
return os.path.join(hub.CACHE_HOME, MODEL_LIST_FILE)
def _load_model_list_file_if_valid(self):
self.model_list_file = {}
if not os.path.exists(self.model_list_file_path()):
return False
file_create_time = os.path.getctime(self.model_list_file_path())
now_time = time.time()
# if file is out of date, remove it
if now_time - file_create_time >= CACHE_TIME:
os.remove(self.model_list_file_path())
return False
self.model_list_file = csv_reader.read(self.model_list_file_path())
# if file do not contain necessary data, remove it
if "version" not in self.model_list_file or "model_name" not in self.model_list_file:
self.model_list_file = {}
os.remove(self.model_list_file_path())
def _load_resource_list_file_if_valid(self):
self.resource_list_file = {}
if not os.path.exists(self.resource_list_file_path()):
return False
return True
def _load_module_list_file_if_valid(self):
self.module_list_file = {}
if not os.path.exists(self.module_list_file_path()):
return False
file_create_time = os.path.getctime(self.module_list_file_path())
file_create_time = os.path.getctime(self.resource_list_file_path())
now_time = time.time()
# if file is out of date, remove it
if now_time - file_create_time >= CACHE_TIME:
os.remove(self.module_list_file_path())
os.remove(self.resource_list_file_path())
return False
self.module_list_file = csv_reader.read(self.module_list_file_path())
for resource in yaml_reader.read(
self.resource_list_file_path())['resource_list']:
for key in resource:
if key not in self.resource_list_file:
self.resource_list_file[key] = []
self.resource_list_file[key].append(resource[key])
# if file do not contain necessary data, remove it
if "version" not in self.module_list_file or "module_name" not in self.module_list_file:
self.module_list_file = {}
os.remove(self.module_list_file_path())
if "version" not in self.resource_list_file or "name" not in self.resource_list_file:
self.resource_list_file = {}
os.remove(self.resource_list_file_path())
return False
return True
def search_module(self, module_key, update=False):
if update or not self.module_list_file:
def search_resource(self, resource_key, resource_type=None, update=False):
if update or not self.resource_list_file:
self.request()
match_module_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module_key in module
]
return [(self.module_list_file['module_name'][index],
self.module_list_file['version'][index])
for index in match_module_index_list]
def search_model(self, model_key, update=False):
if update or not self.model_list_file:
self.request_model()
match_model_index_list = [
match_resource_index_list = [
index
for index, model in enumerate(self.model_list_file['model_name'])
if model_key in model
for index, resource in enumerate(self.resource_list_file['name'])
if resource_key in resource and (
resource_type is None
or self.resource_list_file['type'][index] == resource_type)
]
return [(self.model_list_file['model_name'][index],
self.model_list_file['version'][index])
for index in match_model_index_list]
return [(self.resource_list_file['name'][index],
self.resource_list_file['type'][index],
self.resource_list_file['version'][index],
self.resource_list_file['summary'][index])
for index in match_resource_index_list]
def get_module_url(self, module_name, version=None, update=False):
if update or not self.module_list_file:
def search_module(self, module_key, update=False):
self.search_resource(
resource_key=module_key, resource_type="Module", update=update)
def search_model(self, module_key, update=False):
self.search_resource(
resource_key=module_key, resource_type="Model", update=update)
def get_resource_url(self,
resource_name,
resource_type=None,
version=None,
update=False):
if update or not self.resource_list_file:
self.request()
module_index_list = [
resource_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module == module_name
for index, resource in enumerate(self.resource_list_file['name'])
if resource == resource_name and (
resource_type is None
or self.resource_list_file['type'][index] == resource_type)
]
module_version_list = [
self.module_list_file['version'][index]
for index in module_index_list
resource_version_list = [
self.resource_list_file['version'][index]
for index in resource_index_list
]
#TODO(wuzewu): version sort method
module_version_list = sorted(module_version_list)
resource_version_list = sorted(resource_version_list)
if not version:
if not module_version_list:
if not resource_version_list:
return None
version = module_version_list[-1]
version = resource_version_list[-1]
for index in module_index_list:
if self.module_list_file['version'][index] == version:
return self.module_list_file['url'][index]
for index in resource_index_list:
if self.resource_list_file['version'][index] == version:
return self.resource_list_file['url'][index]
return None
def get_model_url(self, model_name, version=None, update=False):
if update or not self.model_list_file:
self.request_model()
model_index_list = [
index
for index, model in enumerate(self.model_list_file['model_name'])
if model == model_name
]
model_version_list = [
self.model_list_file['version'][index] for index in model_index_list
]
#TODO(wuzewu): version sort method
model_version_list = sorted(model_version_list)
if not version:
if not model_version_list:
return None
version = model_version_list[-1]
for index in model_index_list:
if self.model_list_file['version'][index] == version:
return self.model_list_file['url'][index]
return None
def get_module_url(self, module_name, version=None, update=False):
return self.get_resource_url(
resource_name=module_name,
resource_type="Module",
version=version,
update=update)
def get_model_url(self, module_name, version=None, update=False):
return self.get_resource_url(
resource_name=module_name,
resource_type="Model",
version=version,
update=update)
def request(self):
file_url = self.server_url + MODULE_LIST_FILE
result, tips, self.module_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME)
if not result:
return False
return self._load_module_list_file_if_valid()
def request_model(self):
file_url = self.server_url + MODEL_LIST_FILE
result, tips, self.model_list_file = default_downloader.download_file(
file_url = self.server_url + RESOURCE_LIST_FILE
result, tips, self.resource_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME)
if not result:
return False
return self._load_model_list_file_if_valid()
return self._load_resource_list_file_if_valid()
default_hub_server = HubServer()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册