diff --git a/paddle_hub/commands/download.py b/paddle_hub/commands/download.py index b06b2c4cc25daf17302131e5490dec1be50b6cd9..ff693d30c621553afb84873043edd6519b628203 100644 --- a/paddle_hub/commands/download.py +++ b/paddle_hub/commands/download.py @@ -29,38 +29,38 @@ class DownloadCommand(BaseCommand): def __init__(self, name): super(DownloadCommand, self).__init__(name) self.show_in_help = True - self.description = "Download a paddle hub module." + self.description = "Download a baidu NLP model." self.parser = self.parser = argparse.ArgumentParser( description=self.__class__.__doc__, prog='%s %s ' % (ENTRY, name), usage='%(prog)s [options]', add_help=False) # yapf: disable - self.add_arg('--output_path', str, ".", "path to save the module" ) + self.add_arg('--output_path', str, ".", "path to save the model" ) self.add_arg('--uncompress', bool, False, "uncompress the download package or not" ) # yapf: enable def exec(self, argv): if not argv: - print("ERROR: Please specify a module\n") + print("ERROR: Please specify a model name\n") self.help() return False - module_name = argv[0] - module_version = None if "==" not in module_name else module_name.split( + model_name = argv[0] + model_version = None if "==" not in model_name else model_name.split( "==")[1] - module_name = module_name if "==" not in module_name else module_name.split( + model_name = model_name if "==" not in model_name else model_name.split( "==")[0] self.args = self.parser.parse_args(argv[1:]) if not self.args.output_path: self.args.output_path = "." utils.check_path(self.args.output_path) - url = default_hub_server.get_module_url( - module_name, version=module_version) + url = default_hub_server.get_model_url( + model_name, version=model_version) if not url: - tips = "can't found module %s" % module_name - if module_version: - tips += " with version %s" % module_version + tips = "can't found model %s" % model_name + if model_version: + tips += " with version %s" % model_version print(tips) return True diff --git a/paddle_hub/commands/run.py b/paddle_hub/commands/run.py index a36c4d264c081677356055e9869c60b14b23f955..3fe8f1a2879726d6e108680df05489e352a5bba6 100644 --- a/paddle_hub/commands/run.py +++ b/paddle_hub/commands/run.py @@ -115,7 +115,7 @@ class RunCommand(BaseCommand): origin_data = csv_reader.read(self.args.dataset) else: print("ERROR! Please specify data to predict.\n") - print("Summary:\n %s" % module.summary) + print("Summary:\n %s\n" % module.summary) print("Example:\n %s" % self.demo_with_module(module)) return False diff --git a/paddle_hub/commands/show.py b/paddle_hub/commands/show.py index 3895621c3f122987b4523b290ea31f3e12aba6e3..58ed0bfda8cb21538ffd0b059b7669ef39c16f6f 100644 --- a/paddle_hub/commands/show.py +++ b/paddle_hub/commands/show.py @@ -23,6 +23,7 @@ from paddle_hub.common.logger import logger from paddle_hub.commands.base_command import BaseCommand, ENTRY from paddle_hub.module.manager import default_module_manager from paddle_hub.module.module import Module +from paddle_hub.io.reader import yaml_reader class ShowCommand(BaseCommand): @@ -40,12 +41,26 @@ class ShowCommand(BaseCommand): def exec(self, argv): if not argv: - print("ERROR: Please specify a module\n") + print("ERROR: Please specify a module or a model\n") self.help() return False module_name = argv[0] + # nlp model + model_info = os.path.join(module_name, "info.yml") + if os.path.exists(model_info): + model_info = yaml_reader.read(model_info) + show_text = "Name:%s\n" % model_info['name'] + show_text += "Type:%s\n" % model_info['type'] + show_text += "Version:%s\n" % model_info['version'] + show_text += "Summary:\n" + show_text += " %s\n" % model_info['description'] + show_text += "Author:%s\n" % model_info['author'] + show_text += "Author-Email:%s\n" % model_info['author_email'] + print(show_text) + return True + cwd = os.getcwd() module_dir = default_module_manager.search_module(module_name) module_dir = os.path.join(cwd, diff --git a/paddle_hub/common/hub_server.py b/paddle_hub/common/hub_server.py index 3793926efe7f6bc073a434933c05a73d16937e7d..cca2d97a387b90d5eb4fd244c9a6c470d4a42c43 100644 --- a/paddle_hub/common/hub_server.py +++ b/paddle_hub/common/hub_server.py @@ -23,6 +23,7 @@ import time import paddle_hub as hub MODULE_LIST_FILE = "module_list_file.csv" +MODEL_LIST_FILE = "model_list_file.csv" CACHE_TIME = 60 * 10 @@ -33,10 +34,34 @@ class HubServer: utils.check_url(server_url) self.server_url = server_url self._load_module_list_file_if_valid() + self._load_model_list_file_if_valid() def module_list_file_path(self): return os.path.join(hub.CACHE_HOME, MODULE_LIST_FILE) + def model_list_file_path(self): + return os.path.join(hub.CACHE_HOME, MODEL_LIST_FILE) + + def _load_model_list_file_if_valid(self): + self.model_list_file = {} + if not os.path.exists(self.model_list_file_path()): + return False + file_create_time = os.path.getctime(self.model_list_file_path()) + now_time = time.time() + + # if file is out of date, remove it + if now_time - file_create_time >= CACHE_TIME: + os.remove(self.model_list_file_path()) + return False + self.model_list_file = csv_reader.read(self.model_list_file_path()) + + # if file do not contain necessary data, remove it + if "version" not in self.model_list_file or "model_name" not in self.model_list_file: + self.model_list_file = {} + os.remove(self.model_list_file_path()) + return False + return True + def _load_module_list_file_if_valid(self): self.module_list_file = {} if not os.path.exists(self.module_list_file_path()): @@ -71,6 +96,20 @@ class HubServer: self.module_list_file['version'][index]) for index in match_module_index_list] + def search_model(self, model_key, update=False): + if update or not self.model_list_file: + self.request_model() + + match_model_index_list = [ + index + for index, model in enumerate(self.model_list_file['model_name']) + if model_key in model + ] + + return [(self.model_list_file['model_name'][index], + self.model_list_file['version'][index]) + for index in match_model_index_list] + def get_module_url(self, module_name, version=None, update=False): if update or not self.module_list_file: self.request() @@ -97,6 +136,31 @@ class HubServer: return None + def get_model_url(self, model_name, version=None, update=False): + if update or not self.model_list_file: + self.request_model() + + model_index_list = [ + index + for index, model in enumerate(self.model_list_file['model_name']) + if model == model_name + ] + model_version_list = [ + self.model_list_file['version'][index] for index in model_index_list + ] + #TODO(wuzewu): version sort method + model_version_list = sorted(model_version_list) + if not version: + if not model_version_list: + return None + version = model_version_list[-1] + + for index in model_index_list: + if self.model_list_file['version'][index] == version: + return self.model_list_file['url'][index] + + return None + def request(self): file_url = self.server_url + MODULE_LIST_FILE result, tips, self.module_list_file = default_downloader.download_file( @@ -105,5 +169,13 @@ class HubServer: return False return self._load_module_list_file_if_valid() + def request_model(self): + file_url = self.server_url + MODEL_LIST_FILE + result, tips, self.model_list_file = default_downloader.download_file( + file_url, save_path=hub.CACHE_HOME) + if not result: + return False + return self._load_model_list_file_if_valid() + default_hub_server = HubServer()