提交 e93b3b67 编写于 作者: W wuzewu

Optimize the run command to support model custom parameter parsing

上级 f4e67a1a
...@@ -23,6 +23,8 @@ import os ...@@ -23,6 +23,8 @@ import os
import sys import sys
import six import six
import pandas
import numpy as np
from paddlehub.commands.base_command import BaseCommand, ENTRY from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.io.parser import yaml_parser, txt_parser from paddlehub.io.parser import yaml_parser, txt_parser
...@@ -32,60 +34,27 @@ from paddlehub.common.arg_helper import add_argument, print_arguments ...@@ -32,60 +34,27 @@ from paddlehub.common.arg_helper import add_argument, print_arguments
import paddlehub as hub import paddlehub as hub
class DataFormatError(Exception):
def __init__(self, *args):
self.args = args
class RunCommand(BaseCommand): class RunCommand(BaseCommand):
name = "run" name = "run"
def __init__(self, name): def __init__(self, name):
super(RunCommand, self).__init__(name) super(RunCommand, self).__init__(name)
self.show_in_help = True self.show_in_help = True
self.name = name
self.description = "Run the specific module." self.description = "Run the specific module."
self.parser = self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
description=self.__class__.__doc__, description=self.__class__.__doc__,
prog='%s %s <module>' % (ENTRY, name), prog='%s %s <module>' % (ENTRY, self.name),
usage='%(prog)s', usage='%(prog)s',
add_help=False) add_help=False)
self.module = None
def parse_args_with_module(self, module, argv): def find_module(self, module_name):
module_type = module.type.lower()
# yapf: disable
if module_type.startswith("cv"):
self.add_arg('--config', str, None, "config file in yaml format" )
self.add_arg('--signature', str, None, "signature to run" )
self.add_arg('--input_path', str, None, "path of image to predict" )
self.add_arg('--input_file', str, None, "file contain paths of images" )
self.args = self.parser.parse_args(argv)
self.args.data = self.args.input_path
self.args.dataset = self.args.input_file
elif module_type.startswith("nlp"):
self.add_arg('--config', str, None, "config file in yaml format" )
self.add_arg('--signature', str, None, "signature to run" )
self.add_arg('--input_text', str, None, "text to predict" )
self.add_arg('--input_file', str, None, "file contain texts" )
self.args = self.parser.parse_args(argv)
self.args.data = self.args.input_text
self.args.dataset = self.args.input_file
# yapf: enable
def demo_with_module(self, module):
module_type = module.type.lower()
entry = hub.commands.base_command.ENTRY
if module_type.startswith("cv"):
demo = "%s %s %s --input_path <IMAGE_PATH>" % (entry, self.name,
module.name)
elif module_type.startswith("nlp"):
demo = "%s %s %s --input_text \"TEXT_TO_PREDICT\"" % (
entry, self.name, module.name)
else:
demo = "%s %s %s" % (entry, self.name, module.name)
return demo
def execute(self, argv):
if not argv:
print("ERROR: Please specify a module name.\n")
self.help()
return False
module_name = argv[0]
module_dir = default_module_manager.search_module(module_name) module_dir = default_module_manager.search_module(module_name)
if not module_dir: if not module_dir:
if os.path.exists(module_name): if os.path.exists(module_name):
...@@ -96,87 +65,177 @@ class RunCommand(BaseCommand): ...@@ -96,87 +65,177 @@ class RunCommand(BaseCommand):
module_name) module_name)
print(tips) print(tips)
if not result: if not result:
return False return None
return hub.Module(module_dir=module_dir)
def add_module_config_arg(self):
configs = self.module.processor.configs()
for config in configs:
if not config["dest"].startswith("--"):
config["dest"] = "--%s" % config["dest"]
self.arg_config_group.add_argument(
config["dest"],
type=config['type'],
default=config['default'],
help=config['help'])
self.arg_config_group.add_argument(
'--config',
type=str,
default=None,
help="config file in yaml format")
def add_module_input_arg(self):
module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
self.arg_input_group.add_argument(
'--input_file',
type=str,
default=None,
help="file contain input data")
if len(expect_data_format) == 1:
if module_type.startswith("cv"):
self.arg_input_group.add_argument(
'--input_path',
type=str,
default=None,
help="path of image/video to predict")
elif module_type.startswith("nlp"):
self.arg_input_group.add_argument(
'--input_text',
type=str,
default=None,
help="text to predict")
else:
for key in expect_data_format.keys():
help_str = None
if 'help' in expect_data_format[key]:
help_str = expect_data_format[key]['help']
self.arg_input_group.add_argument(
"--%s" % key, type=str, default=None, help=help_str)
def get_config(self):
yaml_config = {}
if self.args.config:
yaml_config = yaml_parser.parse(self.args.config)
module_config = yaml_config.get("config", {})
for _config in self.module.processor.configs():
key = _config['dest']
module_config[key] = self.args.__dict__[key]
return module_config
def get_data(self):
module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
input_data = {}
if len(expect_data_format) == 1:
key = list(expect_data_format.keys())[0]
if self.args.input_file:
input_data[key] = txt_parser.parse(self.args.input_file)
else:
if module_type.startswith("cv"):
input_data[key] = [self.args.input_path]
elif module_type.startswith("nlp"):
input_data[key] = [self.args.input_text]
else:
for key in expect_data_format.keys():
input_data[key] = [self.args.__dict__[key]]
try: if self.args.input_file:
module = hub.Module(module_dir=module_dir) input_data = pandas.read_csv(self.args.input_file, sep="\t")
except:
return input_data
def check_data(self, data):
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
if len(data.keys()) != len(expect_data_format.keys()):
print( print(
"ERROR! %s is a model. The command run is only for the module type but not the model type." "ERROR: The number of keys in input file is inconsistent with expectations."
% module_name) )
sys.exit(0) raise DataFormatError
if isinstance(data, pandas.DataFrame):
if data.isnull().sum().sum() != 0:
print(
"ERROR: The number of values in input file is inconsistent with expectations."
)
raise DataFormatError
for key, values in data.items():
if not key in expect_data_format.keys():
print("ERROR! Key <%s> in input file is unexpected.\n" % key)
raise DataFormatError
for value in values:
if not value:
print(
"ERROR: The number of values in input file is inconsistent with expectations."
)
raise DataFormatError
self.parse_args_with_module(module, argv[1:]) def execute(self, argv):
if not argv:
print("ERROR: Please specify a module name.\n")
self.help()
return False
module_name = argv[0]
if not module.default_signature: self.parser.prog = '%s %s %s' % (ENTRY, self.name, module_name)
print("ERROR! Module %s is not able to predict." % module_name) self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Data input to the module")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required")
self.module = self.find_module(module_name)
if not self.module:
return False return False
if not self.args.signature: # If the module is not executable, give an alarm and exit
self.args.signature = module.default_signature.name if not self.module.default_signature:
# module processor check print("ERROR! Module %s is not executable." % module_name)
module.check_processor()
expect_data_format = module.processor.data_format(self.args.signature)
# get data dict
if self.args.data:
input_data_key = list(expect_data_format.keys())[0]
origin_data = {input_data_key: [self.args.data]}
elif self.args.dataset:
input_data_key = list(expect_data_format.keys())[0]
origin_data = {input_data_key: txt_parser.parse(self.args.dataset)}
else:
print("ERROR! Please specify data to predict.\n")
print("Summary:\n %s\n" % module.summary)
print("Example:\n %s" % self.demo_with_module(module))
return False return False
# data_format check self.module.check_processor()
if not self.args.config: self.add_module_config_arg()
if len(expect_data_format) != 1: self.add_module_input_arg()
raise RuntimeError(
"Module requires %d inputs, please use config file to specify mappings for data and inputs." if not argv[1:]:
% len(expect_data_format)) self.help()
origin_data_key = list(origin_data.keys())[0] return False
input_data_key = list(expect_data_format.keys())[0]
input_data = {input_data_key: origin_data[origin_data_key]} self.args = self.parser.parse_args(argv[1:])
config = {}
else: config = self.get_config()
yaml_config = yaml_parser.parse(self.args.config) data = self.get_data()
if len(expect_data_format) == 1:
origin_data_key = list(origin_data.keys())[0] try:
input_data_key = list(expect_data_format.keys())[0] self.check_data(data)
input_data = {input_data_key: origin_data[origin_data_key]} except DataFormatError:
else: self.help()
input_data_format = yaml_config['input_data'] return False
if len(input_data_format) != len(expect_data_format):
raise ValueError( results = self.module(
"Module requires %d inputs, but the input file gives %d." sign_name=self.module.default_signature.name, data=data, **config)
% (len(expect_data_format), len(input_data_format)))
for key, value in expect_data_format.items():
if key not in input_data_format:
raise KeyError(
"Input file gives an unexpected input %s" % key)
if value['type'] != hub.DataType.type(
input_data_format[key]['type']):
raise TypeError(
"Module expect Type %s for %s, but the input file gives %s"
% (value['type'], key,
hub.DataType.type(
input_data_format[key]['type'])))
input_data = {}
for key, value in yaml_config['input_data'].items():
input_data[key] = origin_data[value['key']]
config = yaml_config.get("config", {})
# run module with data
results = module(
sign_name=self.args.signature, data=input_data, **config)
if six.PY2: if six.PY2:
print(json.dumps(results, encoding="utf8", ensure_ascii=False)) try:
else: results = json.dumps(
print(results) results, encoding="utf8", ensure_ascii=False)
except:
pass
print(results)
return True
command = RunCommand.instance() command = RunCommand.instance()
...@@ -30,7 +30,7 @@ def add_argument(argument, type, default, help, argparser, **kwargs): ...@@ -30,7 +30,7 @@ def add_argument(argument, type, default, help, argparser, **kwargs):
argument, argument,
default=default, default=default,
type=type, type=type,
help=help + ' Default: %(default)s.', help=help + ' Default: %(default)s.' if help else help,
**kwargs) **kwargs)
......
...@@ -22,6 +22,9 @@ class BaseProcessor(object): ...@@ -22,6 +22,9 @@ class BaseProcessor(object):
def __init__(self, module): def __init__(self, module):
pass pass
def configs(self):
return []
def preprocess(self, sign_name, data_dict): def preprocess(self, sign_name, data_dict):
raise NotImplementedError( raise NotImplementedError(
"BaseProcessor' preprocess should not be called!") "BaseProcessor' preprocess should not be called!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册