提交 10e69bd5 编写于 作者: W wuzewu

Fix bug

上级 b1e6b364
......@@ -74,7 +74,7 @@ class RunCommand(BaseCommand):
return hub.Module(directory=module_dir[0])
def add_module_config_arg(self):
configs = self.module.configs()
configs = self.module.processor.configs()
for config in configs:
if not config["dest"].startswith("--"):
config["dest"] = "--%s" % config["dest"]
......@@ -104,7 +104,7 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self):
module_type = self.module.type.lower()
expect_data_format = self.module.data_format(
expect_data_format = self.module.processor.data_format(
self.module.default_signature)
self.arg_input_group.add_argument(
'--input_file',
......@@ -144,14 +144,14 @@ class RunCommand(BaseCommand):
if self.args.config:
yaml_config = yaml_parser.parse(self.args.config)
module_config = yaml_config.get("config", {})
for _config in self.module.configs():
for _config in self.module.processor.configs():
key = _config['dest']
module_config[key] = self.args.__dict__[key]
return module_config
def get_data(self):
module_type = self.module.type.lower()
expect_data_format = self.module.data_format(
expect_data_format = self.module.processor.data_format(
self.module.default_signature)
input_data = {}
if len(expect_data_format) == 1:
......@@ -176,7 +176,7 @@ class RunCommand(BaseCommand):
return input_data
def check_data(self, data):
expect_data_format = self.module.data_format(
expect_data_format = self.module.processor.data_format(
self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()):
......@@ -236,35 +236,38 @@ class RunCommand(BaseCommand):
return False
# If the module is not executable, give an alarm and exit
if not self.module.default_signature:
if not self.module.is_runable:
print("ERROR! Module %s is not executable." % module_name)
return False
self.module.check_processor()
self.add_module_config_arg()
self.add_module_input_arg()
if self.module.code_version == "v2":
result = self.module(argv[1:])
else:
self.module.check_processor()
self.add_module_config_arg()
self.add_module_input_arg()
if not argv[1:]:
self.help()
return False
if not argv[1:]:
self.help()
return False
self.args = self.parser.parse_args(argv[1:])
self.args = self.parser.parse_args(argv[1:])
config = self.get_config()
data = self.get_data()
config = self.get_config()
data = self.get_data()
try:
self.check_data(data)
except DataFormatError:
self.help()
return False
results = self.module(
sign_name=self.module.default_signature,
data=data,
use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size,
**config)
try:
self.check_data(data)
except DataFormatError:
self.help()
return False
results = self.module(
sign_name=self.module.default_signature,
data=data,
use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size,
**config)
if six.PY2:
try:
......
......@@ -137,6 +137,7 @@ class Module(object):
version=None):
if not directory:
return
self._code_version = "v2"
self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
self._desc = module_desc_pb2.ModuleDesc()
......@@ -225,6 +226,14 @@ class Module(object):
def name_prefix(self):
return self._name_prefix
@property
def code_version(self):
return self._code_version
@property
def is_runable(self):
return False
class ModuleHelper(object):
def __init__(self, directory):
......@@ -252,6 +261,7 @@ class ModuleV1(Module):
if not directory:
return
super(ModuleV1, self).__init__(name, directory, module_dir, version)
self._code_version = "v1"
self.program = None
self.assets = []
self.helper = None
......@@ -501,11 +511,9 @@ class ModuleV1(Module):
if not self.processor:
raise ValueError("This Module is not callable!")
def configs(self):
return self.processor.configs()
def data_format(self, signature):
return self.processor.data_format(signature)
@property
def is_runable(self):
return self.default_signature != None
def context(self,
sign_name=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册