diff --git a/paddlehub/commands/run.py b/paddlehub/commands/run.py index 60049ee8a5ffba3fd103b9cda3e5e4786d1fd689..9330af53f6545ad3dacffb71b6fbd09cdc24f09e 100644 --- a/paddlehub/commands/run.py +++ b/paddlehub/commands/run.py @@ -74,7 +74,7 @@ class RunCommand(BaseCommand): return hub.Module(directory=module_dir[0]) def add_module_config_arg(self): - configs = self.module.configs() + configs = self.module.processor.configs() for config in configs: if not config["dest"].startswith("--"): config["dest"] = "--%s" % config["dest"] @@ -104,7 +104,7 @@ class RunCommand(BaseCommand): def add_module_input_arg(self): module_type = self.module.type.lower() - expect_data_format = self.module.data_format( + expect_data_format = self.module.processor.data_format( self.module.default_signature) self.arg_input_group.add_argument( '--input_file', @@ -144,14 +144,14 @@ class RunCommand(BaseCommand): if self.args.config: yaml_config = yaml_parser.parse(self.args.config) module_config = yaml_config.get("config", {}) - for _config in self.module.configs(): + for _config in self.module.processor.configs(): key = _config['dest'] module_config[key] = self.args.__dict__[key] return module_config def get_data(self): module_type = self.module.type.lower() - expect_data_format = self.module.data_format( + expect_data_format = self.module.processor.data_format( self.module.default_signature) input_data = {} if len(expect_data_format) == 1: @@ -176,7 +176,7 @@ class RunCommand(BaseCommand): return input_data def check_data(self, data): - expect_data_format = self.module.data_format( + expect_data_format = self.module.processor.data_format( self.module.default_signature) if len(data.keys()) != len(expect_data_format.keys()): @@ -236,35 +236,38 @@ class RunCommand(BaseCommand): return False # If the module is not executable, give an alarm and exit - if not self.module.default_signature: + if not self.module.is_runable: print("ERROR! Module %s is not executable." % module_name) return False - self.module.check_processor() - self.add_module_config_arg() - self.add_module_input_arg() + if self.module.code_version == "v2": + result = self.module(argv[1:]) + else: + self.module.check_processor() + self.add_module_config_arg() + self.add_module_input_arg() - if not argv[1:]: - self.help() - return False + if not argv[1:]: + self.help() + return False - self.args = self.parser.parse_args(argv[1:]) + self.args = self.parser.parse_args(argv[1:]) - config = self.get_config() - data = self.get_data() + config = self.get_config() + data = self.get_data() - try: - self.check_data(data) - except DataFormatError: - self.help() - return False - - results = self.module( - sign_name=self.module.default_signature, - data=data, - use_gpu=self.args.use_gpu, - batch_size=self.args.batch_size, - **config) + try: + self.check_data(data) + except DataFormatError: + self.help() + return False + + results = self.module( + sign_name=self.module.default_signature, + data=data, + use_gpu=self.args.use_gpu, + batch_size=self.args.batch_size, + **config) if six.PY2: try: diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index bd24532c2187d8d00129d308e6fd9e6cffd2eed8..b6b720f0108dcc15e325172659c6087b4fcbc75e 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -137,6 +137,7 @@ class Module(object): version=None): if not directory: return + self._code_version = "v2" self._directory = directory self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME) self._desc = module_desc_pb2.ModuleDesc() @@ -225,6 +226,14 @@ class Module(object): def name_prefix(self): return self._name_prefix + @property + def code_version(self): + return self._code_version + + @property + def is_runable(self): + return False + class ModuleHelper(object): def __init__(self, directory): @@ -252,6 +261,7 @@ class ModuleV1(Module): if not directory: return super(ModuleV1, self).__init__(name, directory, module_dir, version) + self._code_version = "v1" self.program = None self.assets = [] self.helper = None @@ -501,11 +511,9 @@ class ModuleV1(Module): if not self.processor: raise ValueError("This Module is not callable!") - def configs(self): - return self.processor.configs() - - def data_format(self, signature): - return self.processor.data_format(signature) + @property + def is_runable(self): + return self.default_signature != None def context(self, sign_name=None,