提交 5ff5415f 编写于 作者: S Steffy-zxf 提交者: wuzewu

Fix the bug that user can't download the modules (#5)

* Fix the bug that user can't download the modules
上级 44a9c1ab
...@@ -31,40 +31,44 @@ class DownloadCommand(BaseCommand): ...@@ -31,40 +31,44 @@ class DownloadCommand(BaseCommand):
def __init__(self, name): def __init__(self, name):
super(DownloadCommand, self).__init__(name) super(DownloadCommand, self).__init__(name)
self.show_in_help = True self.show_in_help = True
self.description = "Download PaddlePaddle pretrained model files." self.description = "Download PaddlePaddle pretrained model/module files."
self.parser = self.parser = argparse.ArgumentParser( self.parser = self.parser = argparse.ArgumentParser(
description=self.__class__.__doc__, description=self.__class__.__doc__,
prog='%s %s <model_name>' % (ENTRY, name), prog='%s %s <model_name/module_name>' % (ENTRY, name),
usage='%(prog)s [options]', usage='%(prog)s [options]',
add_help=False) add_help=False)
# yapf: disable # yapf: disable
self.add_arg('--output_path', str, ".", "path to save the model" ) self.add_arg("--type", str, "All", "choice: Module/Model/All")
self.add_arg('--output_path', str, ".", "path to save the model/module" )
self.add_arg('--uncompress', bool, False, "uncompress the download package or not" ) self.add_arg('--uncompress', bool, False, "uncompress the download package or not" )
# yapf: enable # yapf: enable
def exec(self, argv): def exec(self, argv):
if not argv: if not argv:
print("ERROR: Please provide the model name\n") print("ERROR: Please provide the model/module name\n")
self.help() self.help()
return False return False
model_name = argv[0] mod_name = argv[0]
model_version = None if "==" not in model_name else model_name.split( mod_version = None if "==" not in mod_name else mod_name.split("==")[1]
"==")[1] mod_name = mod_name if "==" not in mod_name else mod_name.split("==")[0]
model_name = model_name if "==" not in model_name else model_name.split(
"==")[0]
self.args = self.parser.parse_args(argv[1:]) self.args = self.parser.parse_args(argv[1:])
if not self.args.output_path: self.args.type = self.check_type(self.args.type)
self.args.output_path = "."
utils.check_path(self.args.output_path)
search_result = default_hub_server.get_model_url( if self.args.type in ["Module", "Model"]:
model_name, version=model_version) search_result = default_hub_server.get_resource_url(
mod_name, resource_type=self.args.type, version=mod_version)
else:
search_result = default_hub_server.get_resource_url(
mod_name, resource_type="Module", version=mod_version)
if search_result == {}:
search_result = default_hub_server.get_resource_url(
mod_name, resource_type="Model", version=mod_version)
url = search_result.get('url', None) url = search_result.get('url', None)
except_md5_value = search_result.get('md5', None) except_md5_value = search_result.get('md5', None)
if not url: if not url:
tips = "Can't found model %s" % model_name tips = "Can't found model/module %s" % mod_name
if model_version: if model_version:
tips += " with version %s" % model_version tips += " with version %s" % mod_version
print(tips) print(tips)
return True return True
...@@ -93,10 +97,20 @@ class DownloadCommand(BaseCommand): ...@@ -93,10 +97,20 @@ class DownloadCommand(BaseCommand):
result, tips, file = default_downloader.uncompress( result, tips, file = default_downloader.uncompress(
file=file, file=file,
dirname=self.args.output_path, dirname=self.args.output_path,
delete_file=True, delete_file=False,
print_progress=True) print_progress=True)
print(tips) print(tips)
return True return True
def check_type(self, mod_type):
mod_type = mod_type.lower()
if mod_type == "module":
mod_type = "Module"
elif mod_type == "model":
mod_type = "Model"
else:
mod_type = "All"
return mod_type
command = DownloadCommand.instance() command = DownloadCommand.instance()
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
import sys
from paddlehub.commands.base_command import BaseCommand, ENTRY from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.io.parser import yaml_parser, txt_parser from paddlehub.io.parser import yaml_parser, txt_parser
...@@ -93,7 +94,14 @@ class RunCommand(BaseCommand): ...@@ -93,7 +94,14 @@ class RunCommand(BaseCommand):
if not result: if not result:
return False return False
module = hub.Module(module_dir=module_dir) try:
module = hub.Module(module_dir=module_dir)
except:
print(
"ERROR! %s is a model. The command is only for the module type but not the model type."
% module_name)
sys.exit(0)
self.parse_args_with_module(module, argv[1:]) self.parse_args_with_module(module, argv[1:])
if not module.default_signature: if not module.default_signature:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册