提交 270e44ab 编写于 作者: W wuzewu

search command add colorful output

上级 4b4b48ef
...@@ -22,6 +22,7 @@ from paddlehub.common.logger import logger ...@@ -22,6 +22,7 @@ from paddlehub.common.logger import logger
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common.hub_server import default_hub_server from paddlehub.common.hub_server import default_hub_server
from paddlehub.commands.base_command import BaseCommand, ENTRY from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.commands.cml_utils import TablePrinter
class SearchCommand(BaseCommand): class SearchCommand(BaseCommand):
...@@ -43,15 +44,23 @@ class SearchCommand(BaseCommand): ...@@ -43,15 +44,23 @@ class SearchCommand(BaseCommand):
self.help() self.help()
return False return False
module_name = argv[0] resource_name = argv[0]
module_list = default_hub_server.search_module(module_name) resource_list = default_hub_server.search_resource(resource_name)
text = "\n" tp = TablePrinter(
text += color_bold_text( titles=["ResourceName", "Type", "Version", "Summary"],
"red", " %-20s\t\t%s\n" % ("ModuleName", "ModuleVersion")) placeholders=[25, 10, 10, 35])
text += " %-20s\t\t%s\n" % ("--", "--") for resource_name, resource_type, resource_version, resource_summary in resource_list:
for module_name, module_version in module_list: if resource_type == "Module":
text += " %-20s\t\t%s\n" % (module_name, module_version) colors = ["yellow", None, None, None]
print(text) else:
colors = ["light_red", None, None, None]
tp.add_line(
contents=[
resource_name, resource_version, resource_type,
resource_summary
],
colors=colors)
print(tp.get_text())
return True return True
......
...@@ -21,11 +21,10 @@ import time ...@@ -21,11 +21,10 @@ import time
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader from paddlehub.common.downloader import default_downloader
from paddlehub.io.reader import csv_reader from paddlehub.io.reader import yaml_reader
import paddlehub as hub import paddlehub as hub
MODULE_LIST_FILE = "module_list_file.csv" RESOURCE_LIST_FILE = "resource_list_file.yml"
MODEL_LIST_FILE = "model_list_file.csv"
CACHE_TIME = 60 * 10 CACHE_TIME = 60 * 10
...@@ -35,149 +34,115 @@ class HubServer: ...@@ -35,149 +34,115 @@ class HubServer:
server_url = "https://paddlehub.bj.bcebos.com/" server_url = "https://paddlehub.bj.bcebos.com/"
utils.check_url(server_url) utils.check_url(server_url)
self.server_url = server_url self.server_url = server_url
self._load_module_list_file_if_valid() self._load_resource_list_file_if_valid()
self._load_model_list_file_if_valid()
def module_list_file_path(self): def resource_list_file_path(self):
return os.path.join(hub.CACHE_HOME, MODULE_LIST_FILE) return os.path.join(hub.CACHE_HOME, RESOURCE_LIST_FILE)
def model_list_file_path(self): def _load_resource_list_file_if_valid(self):
return os.path.join(hub.CACHE_HOME, MODEL_LIST_FILE) self.resource_list_file = {}
if not os.path.exists(self.resource_list_file_path()):
def _load_model_list_file_if_valid(self):
self.model_list_file = {}
if not os.path.exists(self.model_list_file_path()):
return False
file_create_time = os.path.getctime(self.model_list_file_path())
now_time = time.time()
# if file is out of date, remove it
if now_time - file_create_time >= CACHE_TIME:
os.remove(self.model_list_file_path())
return False
self.model_list_file = csv_reader.read(self.model_list_file_path())
# if file do not contain necessary data, remove it
if "version" not in self.model_list_file or "model_name" not in self.model_list_file:
self.model_list_file = {}
os.remove(self.model_list_file_path())
return False return False
return True file_create_time = os.path.getctime(self.resource_list_file_path())
def _load_module_list_file_if_valid(self):
self.module_list_file = {}
if not os.path.exists(self.module_list_file_path()):
return False
file_create_time = os.path.getctime(self.module_list_file_path())
now_time = time.time() now_time = time.time()
# if file is out of date, remove it # if file is out of date, remove it
if now_time - file_create_time >= CACHE_TIME: if now_time - file_create_time >= CACHE_TIME:
os.remove(self.module_list_file_path()) os.remove(self.resource_list_file_path())
return False return False
self.module_list_file = csv_reader.read(self.module_list_file_path()) for resource in yaml_reader.read(
self.resource_list_file_path())['resource_list']:
for key in resource:
if key not in self.resource_list_file:
self.resource_list_file[key] = []
self.resource_list_file[key].append(resource[key])
# if file do not contain necessary data, remove it # if file do not contain necessary data, remove it
if "version" not in self.module_list_file or "module_name" not in self.module_list_file: if "version" not in self.resource_list_file or "name" not in self.resource_list_file:
self.module_list_file = {} self.resource_list_file = {}
os.remove(self.module_list_file_path()) os.remove(self.resource_list_file_path())
return False return False
return True return True
def search_module(self, module_key, update=False): def search_resource(self, resource_key, resource_type=None, update=False):
if update or not self.module_list_file: if update or not self.resource_list_file:
self.request() self.request()
match_module_index_list = [ match_resource_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module_key in module
]
return [(self.module_list_file['module_name'][index],
self.module_list_file['version'][index])
for index in match_module_index_list]
def search_model(self, model_key, update=False):
if update or not self.model_list_file:
self.request_model()
match_model_index_list = [
index index
for index, model in enumerate(self.model_list_file['model_name']) for index, resource in enumerate(self.resource_list_file['name'])
if model_key in model if resource_key in resource and (
resource_type is None
or self.resource_list_file['type'][index] == resource_type)
] ]
return [(self.model_list_file['model_name'][index], return [(self.resource_list_file['name'][index],
self.model_list_file['version'][index]) self.resource_list_file['type'][index],
for index in match_model_index_list] self.resource_list_file['version'][index],
self.resource_list_file['summary'][index])
for index in match_resource_index_list]
def get_module_url(self, module_name, version=None, update=False): def search_module(self, module_key, update=False):
if update or not self.module_list_file: self.search_resource(
resource_key=module_key, resource_type="Module", update=update)
def search_model(self, module_key, update=False):
self.search_resource(
resource_key=module_key, resource_type="Model", update=update)
def get_resource_url(self,
resource_name,
resource_type=None,
version=None,
update=False):
if update or not self.resource_list_file:
self.request() self.request()
module_index_list = [ resource_index_list = [
index index
for index, module in enumerate(self.module_list_file['module_name']) for index, resource in enumerate(self.resource_list_file['name'])
if module == module_name if resource == resource_name and (
resource_type is None
or self.resource_list_file['type'][index] == resource_type)
] ]
module_version_list = [ resource_version_list = [
self.module_list_file['version'][index] self.resource_list_file['version'][index]
for index in module_index_list for index in resource_index_list
] ]
#TODO(wuzewu): version sort method #TODO(wuzewu): version sort method
module_version_list = sorted(module_version_list) resource_version_list = sorted(resource_version_list)
if not version: if not version:
if not module_version_list: if not resource_version_list:
return None return None
version = module_version_list[-1] version = resource_version_list[-1]
for index in module_index_list: for index in resource_index_list:
if self.module_list_file['version'][index] == version: if self.resource_list_file['version'][index] == version:
return self.module_list_file['url'][index] return self.resource_list_file['url'][index]
return None return None
def get_model_url(self, model_name, version=None, update=False): def get_module_url(self, module_name, version=None, update=False):
if update or not self.model_list_file: return self.get_resource_url(
self.request_model() resource_name=module_name,
resource_type="Module",
model_index_list = [ version=version,
index update=update)
for index, model in enumerate(self.model_list_file['model_name'])
if model == model_name def get_model_url(self, module_name, version=None, update=False):
] return self.get_resource_url(
model_version_list = [ resource_name=module_name,
self.model_list_file['version'][index] for index in model_index_list resource_type="Model",
] version=version,
#TODO(wuzewu): version sort method update=update)
model_version_list = sorted(model_version_list)
if not version:
if not model_version_list:
return None
version = model_version_list[-1]
for index in model_index_list:
if self.model_list_file['version'][index] == version:
return self.model_list_file['url'][index]
return None
def request(self): def request(self):
file_url = self.server_url + MODULE_LIST_FILE file_url = self.server_url + RESOURCE_LIST_FILE
result, tips, self.module_list_file = default_downloader.download_file( result, tips, self.resource_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME)
if not result:
return False
return self._load_module_list_file_if_valid()
def request_model(self):
file_url = self.server_url + MODEL_LIST_FILE
result, tips, self.model_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME) file_url, save_path=hub.CACHE_HOME)
if not result: if not result:
return False return False
return self._load_model_list_file_if_valid() return self._load_resource_list_file_if_valid()
default_hub_server = HubServer() default_hub_server = HubServer()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册