提交 10e69bd5 编写于 作者: W wuzewu

Fix bug

上级 b1e6b364
...@@ -74,7 +74,7 @@ class RunCommand(BaseCommand): ...@@ -74,7 +74,7 @@ class RunCommand(BaseCommand):
return hub.Module(directory=module_dir[0]) return hub.Module(directory=module_dir[0])
def add_module_config_arg(self): def add_module_config_arg(self):
configs = self.module.configs() configs = self.module.processor.configs()
for config in configs: for config in configs:
if not config["dest"].startswith("--"): if not config["dest"].startswith("--"):
config["dest"] = "--%s" % config["dest"] config["dest"] = "--%s" % config["dest"]
...@@ -104,7 +104,7 @@ class RunCommand(BaseCommand): ...@@ -104,7 +104,7 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self): def add_module_input_arg(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature) self.module.default_signature)
self.arg_input_group.add_argument( self.arg_input_group.add_argument(
'--input_file', '--input_file',
...@@ -144,14 +144,14 @@ class RunCommand(BaseCommand): ...@@ -144,14 +144,14 @@ class RunCommand(BaseCommand):
if self.args.config: if self.args.config:
yaml_config = yaml_parser.parse(self.args.config) yaml_config = yaml_parser.parse(self.args.config)
module_config = yaml_config.get("config", {}) module_config = yaml_config.get("config", {})
for _config in self.module.configs(): for _config in self.module.processor.configs():
key = _config['dest'] key = _config['dest']
module_config[key] = self.args.__dict__[key] module_config[key] = self.args.__dict__[key]
return module_config return module_config
def get_data(self): def get_data(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature) self.module.default_signature)
input_data = {} input_data = {}
if len(expect_data_format) == 1: if len(expect_data_format) == 1:
...@@ -176,7 +176,7 @@ class RunCommand(BaseCommand): ...@@ -176,7 +176,7 @@ class RunCommand(BaseCommand):
return input_data return input_data
def check_data(self, data): def check_data(self, data):
expect_data_format = self.module.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature) self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()): if len(data.keys()) != len(expect_data_format.keys()):
...@@ -236,35 +236,38 @@ class RunCommand(BaseCommand): ...@@ -236,35 +236,38 @@ class RunCommand(BaseCommand):
return False return False
# If the module is not executable, give an alarm and exit # If the module is not executable, give an alarm and exit
if not self.module.default_signature: if not self.module.is_runable:
print("ERROR! Module %s is not executable." % module_name) print("ERROR! Module %s is not executable." % module_name)
return False return False
self.module.check_processor() if self.module.code_version == "v2":
self.add_module_config_arg() result = self.module(argv[1:])
self.add_module_input_arg() else:
self.module.check_processor()
self.add_module_config_arg()
self.add_module_input_arg()
if not argv[1:]: if not argv[1:]:
self.help() self.help()
return False return False
self.args = self.parser.parse_args(argv[1:]) self.args = self.parser.parse_args(argv[1:])
config = self.get_config() config = self.get_config()
data = self.get_data() data = self.get_data()
try: try:
self.check_data(data) self.check_data(data)
except DataFormatError: except DataFormatError:
self.help() self.help()
return False return False
results = self.module( results = self.module(
sign_name=self.module.default_signature, sign_name=self.module.default_signature,
data=data, data=data,
use_gpu=self.args.use_gpu, use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
**config) **config)
if six.PY2: if six.PY2:
try: try:
......
...@@ -137,6 +137,7 @@ class Module(object): ...@@ -137,6 +137,7 @@ class Module(object):
version=None): version=None):
if not directory: if not directory:
return return
self._code_version = "v2"
self._directory = directory self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME) self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
self._desc = module_desc_pb2.ModuleDesc() self._desc = module_desc_pb2.ModuleDesc()
...@@ -225,6 +226,14 @@ class Module(object): ...@@ -225,6 +226,14 @@ class Module(object):
def name_prefix(self): def name_prefix(self):
return self._name_prefix return self._name_prefix
@property
def code_version(self):
return self._code_version
@property
def is_runable(self):
return False
class ModuleHelper(object): class ModuleHelper(object):
def __init__(self, directory): def __init__(self, directory):
...@@ -252,6 +261,7 @@ class ModuleV1(Module): ...@@ -252,6 +261,7 @@ class ModuleV1(Module):
if not directory: if not directory:
return return
super(ModuleV1, self).__init__(name, directory, module_dir, version) super(ModuleV1, self).__init__(name, directory, module_dir, version)
self._code_version = "v1"
self.program = None self.program = None
self.assets = [] self.assets = []
self.helper = None self.helper = None
...@@ -501,11 +511,9 @@ class ModuleV1(Module): ...@@ -501,11 +511,9 @@ class ModuleV1(Module):
if not self.processor: if not self.processor:
raise ValueError("This Module is not callable!") raise ValueError("This Module is not callable!")
def configs(self): @property
return self.processor.configs() def is_runable(self):
return self.default_signature != None
def data_format(self, signature):
return self.processor.data_format(signature)
def context(self, def context(self,
sign_name=None, sign_name=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册