提交 2cc12622 编写于 作者: W wuzewu

update download command

上级 7780e8ef
......@@ -29,38 +29,38 @@ class DownloadCommand(BaseCommand):
def __init__(self, name):
super(DownloadCommand, self).__init__(name)
self.show_in_help = True
self.description = "Download a paddle hub module."
self.description = "Download a baidu NLP model."
self.parser = self.parser = argparse.ArgumentParser(
description=self.__class__.__doc__,
prog='%s %s <module_name>' % (ENTRY, name),
usage='%(prog)s [options]',
add_help=False)
# yapf: disable
self.add_arg('--output_path', str, ".", "path to save the module" )
self.add_arg('--output_path', str, ".", "path to save the model" )
self.add_arg('--uncompress', bool, False, "uncompress the download package or not" )
# yapf: enable
def exec(self, argv):
if not argv:
print("ERROR: Please specify a module\n")
print("ERROR: Please specify a model name\n")
self.help()
return False
module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split(
model_name = argv[0]
model_version = None if "==" not in model_name else model_name.split(
"==")[1]
module_name = module_name if "==" not in module_name else module_name.split(
model_name = model_name if "==" not in model_name else model_name.split(
"==")[0]
self.args = self.parser.parse_args(argv[1:])
if not self.args.output_path:
self.args.output_path = "."
utils.check_path(self.args.output_path)
url = default_hub_server.get_module_url(
module_name, version=module_version)
url = default_hub_server.get_model_url(
model_name, version=model_version)
if not url:
tips = "can't found module %s" % module_name
if module_version:
tips += " with version %s" % module_version
tips = "can't found model %s" % model_name
if model_version:
tips += " with version %s" % model_version
print(tips)
return True
......
......@@ -115,7 +115,7 @@ class RunCommand(BaseCommand):
origin_data = csv_reader.read(self.args.dataset)
else:
print("ERROR! Please specify data to predict.\n")
print("Summary:\n %s" % module.summary)
print("Summary:\n %s\n" % module.summary)
print("Example:\n %s" % self.demo_with_module(module))
return False
......
......@@ -23,6 +23,7 @@ from paddle_hub.common.logger import logger
from paddle_hub.commands.base_command import BaseCommand, ENTRY
from paddle_hub.module.manager import default_module_manager
from paddle_hub.module.module import Module
from paddle_hub.io.reader import yaml_reader
class ShowCommand(BaseCommand):
......@@ -40,12 +41,26 @@ class ShowCommand(BaseCommand):
def exec(self, argv):
if not argv:
print("ERROR: Please specify a module\n")
print("ERROR: Please specify a module or a model\n")
self.help()
return False
module_name = argv[0]
# nlp model
model_info = os.path.join(module_name, "info.yml")
if os.path.exists(model_info):
model_info = yaml_reader.read(model_info)
show_text = "Name:%s\n" % model_info['name']
show_text += "Type:%s\n" % model_info['type']
show_text += "Version:%s\n" % model_info['version']
show_text += "Summary:\n"
show_text += " %s\n" % model_info['description']
show_text += "Author:%s\n" % model_info['author']
show_text += "Author-Email:%s\n" % model_info['author_email']
print(show_text)
return True
cwd = os.getcwd()
module_dir = default_module_manager.search_module(module_name)
module_dir = os.path.join(cwd,
......
......@@ -23,6 +23,7 @@ import time
import paddle_hub as hub
MODULE_LIST_FILE = "module_list_file.csv"
MODEL_LIST_FILE = "model_list_file.csv"
CACHE_TIME = 60 * 10
......@@ -33,10 +34,34 @@ class HubServer:
utils.check_url(server_url)
self.server_url = server_url
self._load_module_list_file_if_valid()
self._load_model_list_file_if_valid()
def module_list_file_path(self):
return os.path.join(hub.CACHE_HOME, MODULE_LIST_FILE)
def model_list_file_path(self):
return os.path.join(hub.CACHE_HOME, MODEL_LIST_FILE)
def _load_model_list_file_if_valid(self):
self.model_list_file = {}
if not os.path.exists(self.model_list_file_path()):
return False
file_create_time = os.path.getctime(self.model_list_file_path())
now_time = time.time()
# if file is out of date, remove it
if now_time - file_create_time >= CACHE_TIME:
os.remove(self.model_list_file_path())
return False
self.model_list_file = csv_reader.read(self.model_list_file_path())
# if file do not contain necessary data, remove it
if "version" not in self.model_list_file or "model_name" not in self.model_list_file:
self.model_list_file = {}
os.remove(self.model_list_file_path())
return False
return True
def _load_module_list_file_if_valid(self):
self.module_list_file = {}
if not os.path.exists(self.module_list_file_path()):
......@@ -71,6 +96,20 @@ class HubServer:
self.module_list_file['version'][index])
for index in match_module_index_list]
def search_model(self, model_key, update=False):
if update or not self.model_list_file:
self.request_model()
match_model_index_list = [
index
for index, model in enumerate(self.model_list_file['model_name'])
if model_key in model
]
return [(self.model_list_file['model_name'][index],
self.model_list_file['version'][index])
for index in match_model_index_list]
def get_module_url(self, module_name, version=None, update=False):
if update or not self.module_list_file:
self.request()
......@@ -97,6 +136,31 @@ class HubServer:
return None
def get_model_url(self, model_name, version=None, update=False):
if update or not self.model_list_file:
self.request_model()
model_index_list = [
index
for index, model in enumerate(self.model_list_file['model_name'])
if model == model_name
]
model_version_list = [
self.model_list_file['version'][index] for index in model_index_list
]
#TODO(wuzewu): version sort method
model_version_list = sorted(model_version_list)
if not version:
if not model_version_list:
return None
version = model_version_list[-1]
for index in model_index_list:
if self.model_list_file['version'][index] == version:
return self.model_list_file['url'][index]
return None
def request(self):
file_url = self.server_url + MODULE_LIST_FILE
result, tips, self.module_list_file = default_downloader.download_file(
......@@ -105,5 +169,13 @@ class HubServer:
return False
return self._load_module_list_file_if_valid()
def request_model(self):
file_url = self.server_url + MODEL_LIST_FILE
result, tips, self.model_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME)
if not result:
return False
return self._load_model_list_file_if_valid()
default_hub_server = HubServer()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册