提交 6dd72f65 编写于 作者: F fary86

Add prim name to error message for nn_ops.py

上级 475f62f6
......@@ -117,10 +117,12 @@ class Validator:
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
f' with type `{type(arg_value).__name__}`.')
return arg_value
@staticmethod
......@@ -137,10 +139,11 @@ class Validator:
"""Method for checking whether an int value is in some range."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
excp_cls = TypeError if type_mismatch else ValueError
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
f' but got {arg_value}.')
raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
return arg_value
@staticmethod
......@@ -192,19 +195,23 @@ class Validator:
@staticmethod
def check_const_input(arg_name, arg_value, prim_name):
"""Check valid value."""
"""Checks valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_scalar_type_same(args, valid_values, prim_name):
"""check whether the types of inputs are the same."""
def check_type_same(args, valid_values, prim_name):
"""Checks whether the types of inputs are the same."""
def _check_tensor_type(arg):
arg_key, arg_val = arg
elem_type = arg_val
type_names = []
if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.')
for t in valid_values:
type_names.append(str(t))
types_info = '[' + ", ".join(type_names) + ']'
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
f' but got {elem_type}.')
return (arg_key, elem_type)
def _check_types_same(arg1, arg2):
......@@ -212,7 +219,7 @@ class Validator:
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
return arg1
elem_types = map(_check_tensor_type, args.items())
......@@ -221,25 +228,8 @@ class Validator:
@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
    """Checks whether the element types of input tensors are the same.

    Wraps every allowed dtype in `mstype.tensor_type` and delegates to
    `Validator.check_type_same`, so tensor element types receive the same
    membership and pairwise-consistency checks as scalar types.

    Args:
        args (dict): Maps argument names to their tensor types.
        valid_values (iterable): The allowed element dtypes.
        prim_name (str): Primitive name used to prefix error messages.

    Raises:
        TypeError: Propagated from `check_type_same` on an invalid or
            mismatched element type.
    """
    tensor_types = [mstype.tensor_type(t) for t in valid_values]
    Validator.check_type_same(args, tensor_types, prim_name)
@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
......
......@@ -34,7 +34,7 @@ GRAPH_MODE = 0
PYNATIVE_MODE = 1
def _make_directory(path: str):
def _make_directory(path):
"""Make directory."""
real_path = None
if path is None or not isinstance(path, str) or path.strip() == "":
......
此差异已折叠。
......@@ -41,7 +41,7 @@ class TestInputs:
dr.piecewise_constant_lr(milestone1, learning_rates)
milestone2 = [1.0, 2.0, True]
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone2, learning_rates)
def test_learning_rates1(self):
......@@ -92,13 +92,13 @@ class TestInputs:
def test_total_step1(self):
    # total_step must be an int; a float argument is a type error,
    # so each decay-lr builder should raise TypeError (not ValueError).
    total_step1 = 2.0
    with pytest.raises(TypeError):
        dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch)
    with pytest.raises(TypeError):
        dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch)
    with pytest.raises(TypeError):
        dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power)
def test_total_step2(self):
......@@ -114,13 +114,13 @@ class TestInputs:
def test_step_per_epoch1(self):
    # step_per_epoch must be an int; bool is explicitly rejected by the
    # validator, so each decay-lr builder should raise TypeError.
    step_per_epoch1 = True
    with pytest.raises(TypeError):
        dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch)
    with pytest.raises(TypeError):
        dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch)
    with pytest.raises(TypeError):
        dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power)
def test_step_per_epoch2(self):
......@@ -136,13 +136,13 @@ class TestInputs:
def test_decay_epoch1(self):
    # decay_epoch must be an int; a string argument is a type error,
    # so each decay-lr builder should raise TypeError (not ValueError).
    decay_epoch1 = 'm'
    with pytest.raises(TypeError):
        dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1)
    with pytest.raises(TypeError):
        dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1)
    with pytest.raises(TypeError):
        dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power)
def test_decay_epoch2(self):
......
......@@ -60,7 +60,7 @@ def test_ssim_max_val_zero():
net = SSIMNet(max_val)
def test_ssim_filter_size_float():
    # filter_size must be an int; a float should now raise TypeError
    # (previously ValueError) from the int validator.
    with pytest.raises(TypeError):
        net = SSIMNet(filter_size=1.1)
def test_ssim_filter_size_zero():
......
......@@ -516,7 +516,7 @@ test_cases = [
test_cases_for_verify_exception = [
('Conv2d_ValueError_1', {
'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}),
'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
'desc_inputs': [0],
}),
('Conv2d_ValueError_2', {
......@@ -528,7 +528,7 @@ test_cases_for_verify_exception = [
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_2', {
'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': ValueError}),
'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': TypeError}),
'desc_inputs': [0],
}),
('MaxPoolWithArgmax_ValueError_3', {
......@@ -540,7 +540,7 @@ test_cases_for_verify_exception = [
'desc_inputs': [0],
}),
('FusedBatchNorm_ValueError_1', {
'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}),
'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}),
'desc_inputs': [0],
}),
('FusedBatchNorm_ValueError_2', {
......@@ -560,31 +560,31 @@ test_cases_for_verify_exception = [
'desc_inputs': [0],
}),
('Softmax_ValueError_1', {
'block': (lambda _: P.Softmax("1"), {'exception': ValueError}),
'block': (lambda _: P.Softmax("1"), {'exception': TypeError}),
'desc_inputs': [0],
}),
('Softmax_ValueError_2', {
'block': (lambda _: P.Softmax(1.1), {'exception': ValueError}),
'block': (lambda _: P.Softmax(1.1), {'exception': TypeError}),
'desc_inputs': [0],
}),
('Softmax_ValueError_3', {
'block': (lambda _: P.Softmax(axis="1"), {'exception': ValueError}),
'block': (lambda _: P.Softmax(axis="1"), {'exception': TypeError}),
'desc_inputs': [0],
}),
('DropoutGenMask_ValueError_1', {
'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}),
'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': TypeError}),
'desc_inputs': [0],
}),
('DropoutGenMask_ValueError_2', {
'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}),
'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': TypeError}),
'desc_inputs': [0],
}),
('DropoutGenMask_ValueError_3', {
'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}),
'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': TypeError}),
'desc_inputs': [0],
}),
('DropoutGenMask_ValueError_4', {
'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}),
'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': TypeError}),
'desc_inputs': [0],
}),
('MaxPool2d_ValueError_1', {
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册