提交 2a93e018 编写于 作者: W wuzewu

replace asserts

上级 5b09fad3
...@@ -30,8 +30,10 @@ class BaseCommand: ...@@ -30,8 +30,10 @@ class BaseCommand:
def instance(cls): def instance(cls):
if cls.name in BaseCommand.command_dict: if cls.name in BaseCommand.command_dict:
command = BaseCommand.command_dict[cls.name] command = BaseCommand.command_dict[cls.name]
assert command.__class__.__name__ == cls.__name__, "already has a command %s with type %s" % ( if command.__class__.__name__ != cls.__name__:
cls.name, command.__class__) raise KeyError(
"Command dict already has a command %s with type %s" %
(cls.name, command.__class__))
return command return command
if not hasattr(cls, '_instance'): if not hasattr(cls, '_instance'):
cls._instance = cls(cls.name) cls._instance = cls(cls.name)
...@@ -39,9 +41,8 @@ class BaseCommand: ...@@ -39,9 +41,8 @@ class BaseCommand:
return cls._instance return cls._instance
def __init__(self, name): def __init__(self, name):
assert not hasattr( if hasattr(self.__class__, '_instance'):
self.__class__, raise RuntimeError("Please use `instance()` to get Command object!")
'_instance'), 'Please use `instance()` to get Command object!'
self.args = None self.args = None
self.name = name self.name = name
self.show_in_help = True self.show_in_help = True
......
...@@ -122,7 +122,10 @@ class RunCommand(BaseCommand): ...@@ -122,7 +122,10 @@ class RunCommand(BaseCommand):
# data_format check # data_format check
if not self.args.config: if not self.args.config:
assert len(expect_data_format) == 1 if len(expect_data_format) != 1:
raise RuntimeError(
"Module requires %d inputs, please use config file to specify mappings for data and inputs."
% len(expect_data_format))
origin_data_key = list(origin_data.keys())[0] origin_data_key = list(origin_data.keys())[0]
input_data_key = list(expect_data_format.keys())[0] input_data_key = list(expect_data_format.keys())[0]
input_data = {input_data_key: origin_data[origin_data_key]} input_data = {input_data_key: origin_data[origin_data_key]}
...@@ -135,11 +138,22 @@ class RunCommand(BaseCommand): ...@@ -135,11 +138,22 @@ class RunCommand(BaseCommand):
input_data = {input_data_key: origin_data[origin_data_key]} input_data = {input_data_key: origin_data[origin_data_key]}
else: else:
input_data_format = yaml_config['input_data'] input_data_format = yaml_config['input_data']
assert len(input_data_format) == len(expect_data_format) if len(input_data_format) != len(expect_data_format):
raise ValueError(
"Module requires %d inputs, but the input file gives %d."
% (len(expect_data_format), len(input_data_format)))
for key, value in expect_data_format.items(): for key, value in expect_data_format.items():
assert key in input_data_format if key not in input_data_format:
assert value['type'] == hub.DataType.type( raise KeyError(
input_data_format[key]['type']) "Input file gives an unexpected input %s" % key)
if value['type'] != hub.DataType.type(
input_data_format[key]['type']):
raise TypeError(
"Module expect Type %s for %s, but the input file gives %s"
% (value['type'], key,
hub.DataType.type(
input_data_format[key]['type'])))
input_data = {} input_data = {}
for key, value in yaml_config['input_data'].items(): for key, value in yaml_config['input_data'].items():
......
...@@ -26,9 +26,9 @@ from paddlehub.common.logger import logger ...@@ -26,9 +26,9 @@ from paddlehub.common.logger import logger
def get_variable_info(var): def get_variable_info(var):
assert isinstance( if not isinstance(var, fluid.framework.Variable):
var, raise TypeError("var shoule be an instance of fluid.framework.Variable")
fluid.framework.Variable), "var should be a fluid.framework.Variable"
var_info = { var_info = {
'type': var.type, 'type': var.type,
'name': var.name, 'name': var.name,
...@@ -148,20 +148,26 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True): ...@@ -148,20 +148,26 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True):
} }
to_block.append_op(**op_info) to_block.append_op(**op_info)
assert isinstance(pre_program, if not isinstance(pre_program, fluid.Program):
fluid.Program), "pre_program should be fluid.Program" raise TypeError("pre_program shoule be an instance of fluid.Program")
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program" if not isinstance(next_program, fluid.Program):
raise TypeError("next_program shoule be an instance of fluid.Program")
output_program = pre_program if inplace else pre_program.clone( output_program = pre_program if inplace else pre_program.clone(
for_test=False) for_test=False)
if input_dict: if input_dict:
assert isinstance( if not isinstance(input_dict, dict):
input_dict, raise TypeError(
dict), "the input_dict should be a dict with string-Variable pair" "input_dict shoule be a python dict like {str:fluid.framework.Variable}"
)
for key, var in input_dict.items(): for key, var in input_dict.items():
assert isinstance( if not isinstance(var, fluid.framework.Variable):
var, fluid.framework.Variable raise TypeError(
), "the input_dict should be a dict with string-Variable pair" "input_dict shoule be a python dict like {str:fluid.framework.Variable}"
)
var_info = copy.deepcopy(get_variable_info(var)) var_info = copy.deepcopy(get_variable_info(var))
input_var = output_program.global_block().create_var(**var_info) input_var = output_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key) output_var = next_program.global_block().var(key)
......
...@@ -117,9 +117,10 @@ class Module(object): ...@@ -117,9 +117,10 @@ class Module(object):
self._init_with_module_file(module_dir=module_dir) self._init_with_module_file(module_dir=module_dir)
elif signatures: elif signatures:
if processor: if processor:
assert issubclass( if not issubclass(processor, BaseProcessor):
processor, BaseProcessor raise TypeError(
), "processor should be sub class of hub.BaseProcessor" "processor shoule be an instance of paddlehub.BaseProcessor"
)
if assets: if assets:
self.assets = utils.to_list(assets) self.assets = utils.to_list(assets)
# for asset in assets: # for asset in assets:
...@@ -446,7 +447,8 @@ class Module(object): ...@@ -446,7 +447,8 @@ class Module(object):
return result return result
def check_processor(self): def check_processor(self):
assert self.processor, "this module couldn't be call" if not self.processor:
raise ValueError("This Module is not callable!")
def context(self, def context(self,
sign_name, sign_name,
...@@ -461,7 +463,9 @@ class Module(object): ...@@ -461,7 +463,9 @@ class Module(object):
available for BERT/ERNIE module available for BERT/ERNIE module
""" """
assert sign_name in self.signatures, "module did not have a signature with name %s" % sign_name if sign_name not in self.signatures:
raise KeyError(
"Module did not have a signature with name %s" % sign_name)
signature = self.signatures[sign_name] signature = self.signatures[sign_name]
program = self.program.clone(for_test=for_test) program = self.program.clone(for_test=for_test)
...@@ -535,19 +539,28 @@ class Module(object): ...@@ -535,19 +539,28 @@ class Module(object):
return self.get_name_prefix() + var_name return self.get_name_prefix() + var_name
def _check_signatures(self): def _check_signatures(self):
assert self.signatures, "Signature array should not be None" if not self.signatures:
raise ValueError("Signatures should not be None")
for key, sign in self.signatures.items(): for key, sign in self.signatures.items():
assert isinstance(sign, if not isinstance(sign, Signature):
Signature), "sign_arr should be list of Signature" raise TypeError(
"Item in Signatures shoule be an instance of paddlehub.Signature"
)
for input in sign.inputs: for input in sign.inputs:
_tmp_program = input.block.program _tmp_program = input.block.program
assert self.program == _tmp_program, "all the variable should come from the same program" if not self.program == _tmp_program:
raise ValueError(
"All input and outputs variables in signature should come from the same Program"
)
for output in sign.outputs: for output in sign.outputs:
_tmp_program = output.block.program _tmp_program = output.block.program
assert self.program == _tmp_program, "all the variable should come from the same program" if not self.program == _tmp_program:
raise ValueError(
"All input and outputs variables in signature should come from the same Program"
)
def serialize_to_path(self, path=None, exe=None): def serialize_to_path(self, path=None, exe=None):
self._check_signatures() self._check_signatures()
......
...@@ -35,23 +35,29 @@ class Signature: ...@@ -35,23 +35,29 @@ class Signature:
if not feed_names: if not feed_names:
feed_names = [""] * len(inputs) feed_names = [""] * len(inputs)
feed_names = to_list(feed_names) feed_names = to_list(feed_names)
assert len(inputs) == len( if len(inputs) != len(feed_names):
feed_names), "the length of feed_names must be same with inputs" raise ValueError(
"The length of feed_names must be same with inputs")
if not fetch_names: if not fetch_names:
fetch_names = [""] * len(outputs) fetch_names = [""] * len(outputs)
fetch_names = to_list(fetch_names) fetch_names = to_list(fetch_names)
assert len(outputs) == len( if len(outputs) != len(fetch_names):
fetch_names), "the length of fetch_names must be same with outputs" raise ValueError(
"the length of fetch_names must be same with outputs")
self.name = name self.name = name
for item in inputs: for item in inputs:
assert isinstance( if not isinstance(item, Variable):
item, Variable), "the item of inputs list shoule be Variable" raise TypeError(
"Item in inputs list shoule be an instance of fluid.framework.Variable"
)
for item in outputs: for item in outputs:
assert isinstance( if not isinstance(item, Variable):
item, Variable), "the item of outputs list shoule be Variable" raise TypeError(
"Item in outputs list shoule be an instance of fluid.framework.Variable"
)
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
......
...@@ -300,7 +300,8 @@ class SequenceLabelReader(BaseReader): ...@@ -300,7 +300,8 @@ class SequenceLabelReader(BaseReader):
return return_list return return_list
def _reseg_token_label(self, tokens, labels, tokenizer): def _reseg_token_label(self, tokens, labels, tokenizer):
assert len(tokens) == len(labels) if len(tokens) != len(labels):
raise ValueError("The length of tokens must be same with labels")
ret_tokens = [] ret_tokens = []
ret_labels = [] ret_labels = []
for token, label in zip(tokens, labels): for token, label in zip(tokens, labels):
...@@ -316,7 +317,8 @@ class SequenceLabelReader(BaseReader): ...@@ -316,7 +317,8 @@ class SequenceLabelReader(BaseReader):
sub_label = "I-" + label[2:] sub_label = "I-" + label[2:]
ret_labels.extend([sub_label] * (len(sub_token) - 1)) ret_labels.extend([sub_label] * (len(sub_token) - 1))
assert len(ret_tokens) == len(ret_labels) if len(ret_tokens) != len(labels):
raise ValueError("The length of ret_tokens can't match with labels")
return ret_tokens, ret_labels return ret_tokens, ret_labels
def _convert_example_to_record(self, example, max_seq_length, tokenizer): def _convert_example_to_record(self, example, max_seq_length, tokenizer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册