diff --git a/paddlehub/commands/run.py b/paddlehub/commands/run.py index d1b837aab4e656284ff20ae7a4e54207de740e59..9656d6fd8e8f1c0216ba0e33c33959a62b2ed17d 100644 --- a/paddlehub/commands/run.py +++ b/paddlehub/commands/run.py @@ -23,6 +23,8 @@ import os import sys import six +import pandas +import numpy as np from paddlehub.commands.base_command import BaseCommand, ENTRY from paddlehub.io.parser import yaml_parser, txt_parser @@ -32,60 +34,27 @@ from paddlehub.common.arg_helper import add_argument, print_arguments import paddlehub as hub +class DataFormatError(Exception): + def __init__(self, *args): + self.args = args + + class RunCommand(BaseCommand): name = "run" def __init__(self, name): super(RunCommand, self).__init__(name) self.show_in_help = True + self.name = name self.description = "Run the specific module." - self.parser = self.parser = argparse.ArgumentParser( + self.parser = argparse.ArgumentParser( description=self.__class__.__doc__, - prog='%s %s ' % (ENTRY, name), + prog='%s %s ' % (ENTRY, self.name), usage='%(prog)s', add_help=False) + self.module = None - def parse_args_with_module(self, module, argv): - module_type = module.type.lower() - # yapf: disable - if module_type.startswith("cv"): - self.add_arg('--config', str, None, "config file in yaml format" ) - self.add_arg('--signature', str, None, "signature to run" ) - self.add_arg('--input_path', str, None, "path of image to predict" ) - self.add_arg('--input_file', str, None, "file contain paths of images" ) - self.args = self.parser.parse_args(argv) - self.args.data = self.args.input_path - self.args.dataset = self.args.input_file - elif module_type.startswith("nlp"): - self.add_arg('--config', str, None, "config file in yaml format" ) - self.add_arg('--signature', str, None, "signature to run" ) - self.add_arg('--input_text', str, None, "text to predict" ) - self.add_arg('--input_file', str, None, "file contain texts" ) - self.args = self.parser.parse_args(argv) - self.args.data = self.args.input_text - self.args.dataset = self.args.input_file - # yapf: enable - - def demo_with_module(self, module): - module_type = module.type.lower() - entry = hub.commands.base_command.ENTRY - if module_type.startswith("cv"): - demo = "%s %s %s --input_path " % (entry, self.name, - module.name) - elif module_type.startswith("nlp"): - demo = "%s %s %s --input_text \"TEXT_TO_PREDICT\"" % ( - entry, self.name, module.name) - else: - demo = "%s %s %s" % (entry, self.name, module.name) - return demo - - def execute(self, argv): - if not argv: - print("ERROR: Please specify a module name.\n") - self.help() - return False - module_name = argv[0] - + def find_module(self, module_name): module_dir = default_module_manager.search_module(module_name) if not module_dir: if os.path.exists(module_name): @@ -96,87 +65,177 @@ class RunCommand(BaseCommand): module_name) print(tips) if not result: - return False + return None + + return hub.Module(module_dir=module_dir) + + def add_module_config_arg(self): + configs = self.module.processor.configs() + for config in configs: + if not config["dest"].startswith("--"): + config["dest"] = "--%s" % config["dest"] + self.arg_config_group.add_argument( + config["dest"], + type=config['type'], + default=config['default'], + help=config['help']) + + self.arg_config_group.add_argument( + '--config', + type=str, + default=None, + help="config file in yaml format") + + def add_module_input_arg(self): + module_type = self.module.type.lower() + expect_data_format = self.module.processor.data_format( + self.module.default_signature.name) + self.arg_input_group.add_argument( + '--input_file', + type=str, + default=None, + help="file contain input data") + if len(expect_data_format) == 1: + if module_type.startswith("cv"): + self.arg_input_group.add_argument( + '--input_path', + type=str, + default=None, + help="path of image/video to predict") + elif module_type.startswith("nlp"): + self.arg_input_group.add_argument( + '--input_text', + type=str, + default=None, + help="text to predict") + else: + for key in expect_data_format.keys(): + help_str = None + if 'help' in expect_data_format[key]: + help_str = expect_data_format[key]['help'] + self.arg_input_group.add_argument( + "--%s" % key, type=str, default=None, help=help_str) + + def get_config(self): + yaml_config = {} + if self.args.config: + yaml_config = yaml_parser.parse(self.args.config) + module_config = yaml_config.get("config", {}) + for _config in self.module.processor.configs(): + key = _config['dest'] + module_config[key] = self.args.__dict__[key] + return module_config + + def get_data(self): + module_type = self.module.type.lower() + expect_data_format = self.module.processor.data_format( + self.module.default_signature.name) + input_data = {} + if len(expect_data_format) == 1: + key = list(expect_data_format.keys())[0] + if self.args.input_file: + input_data[key] = txt_parser.parse(self.args.input_file) + else: + if module_type.startswith("cv"): + input_data[key] = [self.args.input_path] + elif module_type.startswith("nlp"): + input_data[key] = [self.args.input_text] + else: + for key in expect_data_format.keys(): + input_data[key] = [self.args.__dict__[key]] - try: - module = hub.Module(module_dir=module_dir) - except: + if self.args.input_file: + input_data = pandas.read_csv(self.args.input_file, sep="\t") + + return input_data + + def check_data(self, data): + expect_data_format = self.module.processor.data_format( + self.module.default_signature.name) + + if len(data.keys()) != len(expect_data_format.keys()): print( - "ERROR! %s is a model. The command run is only for the module type but not the model type." - % module_name) - sys.exit(0) + "ERROR: The number of keys in input file is inconsistent with expectations." + ) + raise DataFormatError + + if isinstance(data, pandas.DataFrame): + if data.isnull().sum().sum() != 0: + print( + "ERROR: The number of values in input file is inconsistent with expectations." + ) + raise DataFormatError + + for key, values in data.items(): + + if not key in expect_data_format.keys(): + print("ERROR! Key <%s> in input file is unexpected.\n" % key) + raise DataFormatError + + for value in values: + if not value: + print( + "ERROR: The number of values in input file is inconsistent with expectations." + ) + raise DataFormatError - self.parse_args_with_module(module, argv[1:]) + def execute(self, argv): + if not argv: + print("ERROR: Please specify a module name.\n") + self.help() + return False + + module_name = argv[0] - if not module.default_signature: - print("ERROR! Module %s is not able to predict." % module_name) + self.parser.prog = '%s %s %s' % (ENTRY, self.name, module_name) + self.arg_input_group = self.parser.add_argument_group( + title="Input options", description="Data input to the module") + self.arg_config_group = self.parser.add_argument_group( + title="Config options", + description= + "Run configuration for controlling module behavior, not required") + + self.module = self.find_module(module_name) + if not self.module: return False - if not self.args.signature: - self.args.signature = module.default_signature.name - # module processor check - module.check_processor() - expect_data_format = module.processor.data_format(self.args.signature) - - # get data dict - if self.args.data: - input_data_key = list(expect_data_format.keys())[0] - origin_data = {input_data_key: [self.args.data]} - elif self.args.dataset: - input_data_key = list(expect_data_format.keys())[0] - origin_data = {input_data_key: txt_parser.parse(self.args.dataset)} - else: - print("ERROR! Please specify data to predict.\n") - print("Summary:\n %s\n" % module.summary) - print("Example:\n %s" % self.demo_with_module(module)) + # If the module is not executable, give an alarm and exit + if not self.module.default_signature: + print("ERROR! Module %s is not executable." % module_name) return False - # data_format check - if not self.args.config: - if len(expect_data_format) != 1: - raise RuntimeError( - "Module requires %d inputs, please use config file to specify mappings for data and inputs." - % len(expect_data_format)) - origin_data_key = list(origin_data.keys())[0] - input_data_key = list(expect_data_format.keys())[0] - input_data = {input_data_key: origin_data[origin_data_key]} - config = {} - else: - yaml_config = yaml_parser.parse(self.args.config) - if len(expect_data_format) == 1: - origin_data_key = list(origin_data.keys())[0] - input_data_key = list(expect_data_format.keys())[0] - input_data = {input_data_key: origin_data[origin_data_key]} - else: - input_data_format = yaml_config['input_data'] - if len(input_data_format) != len(expect_data_format): - raise ValueError( - "Module requires %d inputs, but the input file gives %d." - % (len(expect_data_format), len(input_data_format))) - for key, value in expect_data_format.items(): - if key not in input_data_format: - raise KeyError( - "Input file gives an unexpected input %s" % key) - - if value['type'] != hub.DataType.type( - input_data_format[key]['type']): - raise TypeError( - "Module expect Type %s for %s, but the input file gives %s" - % (value['type'], key, - hub.DataType.type( - input_data_format[key]['type']))) - - input_data = {} - for key, value in yaml_config['input_data'].items(): - input_data[key] = origin_data[value['key']] - config = yaml_config.get("config", {}) - # run module with data - results = module( - sign_name=self.args.signature, data=input_data, **config) + self.module.check_processor() + self.add_module_config_arg() + self.add_module_input_arg() + + if not argv[1:]: + self.help() + return False + + self.args = self.parser.parse_args(argv[1:]) + + config = self.get_config() + data = self.get_data() + + try: + self.check_data(data) + except DataFormatError: + self.help() + return False + + results = self.module( + sign_name=self.module.default_signature.name, data=data, **config) + if six.PY2: - print(json.dumps(results, encoding="utf8", ensure_ascii=False)) - else: - print(results) + try: + results = json.dumps( + results, encoding="utf8", ensure_ascii=False) + except: + pass + + print(results) + + return True command = RunCommand.instance() diff --git a/paddlehub/common/arg_helper.py b/paddlehub/common/arg_helper.py index 9d8008b522cb010c1d555dbbe6a9fd0347c7d7c5..27ee88ad6b94fb680ce57777d653d4162790babf 100644 --- a/paddlehub/common/arg_helper.py +++ b/paddlehub/common/arg_helper.py @@ -30,7 +30,7 @@ def add_argument(argument, type, default, help, argparser, **kwargs): argument, default=default, type=type, - help=help + ' Default: %(default)s.', + help=help + ' Default: %(default)s.' if help else help, **kwargs) diff --git a/paddlehub/module/base_processor.py b/paddlehub/module/base_processor.py index d5a57d9f7bf0f50f1e53409355001c01a09fabbb..694819cd47284fa798c4fe45dfaaf520757d6ac5 100644 --- a/paddlehub/module/base_processor.py +++ b/paddlehub/module/base_processor.py @@ -22,6 +22,9 @@ class BaseProcessor(object): def __init__(self, module): pass + def configs(self): + return [] + def preprocess(self, sign_name, data_dict): raise NotImplementedError( "BaseProcessor' preprocess should not be called!")