Commit 5add5979 authored by mindspore-ci-bot, committed by Gitee

!92 add prim name to param check error message for math_ops.py

Merge pull request !92 from fary86/add-prim-name-for-param-validator
@@ -15,6 +15,7 @@
"""Check parameters."""
import re
from enum import Enum
from functools import reduce
from itertools import repeat
from collections.abc import Iterable
@@ -93,8 +94,131 @@ rel_strs = {
}
class Validator:
    """Validator for checking input parameters."""

    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
        """
        Check the relation between two int values, or between lists/tuples made up of ints.

        This method is not suitable for floats, since it does not account for
        floating-point error.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            msg_prefix = f'For {prim_name} the' if prim_name else "The"
            raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
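For illustration only (not part of this commit): a minimal standalone sketch of the message format this check produces, with rel hard-coded to a >= comparison and hypothetical argument names:

# Standalone sketch mirroring Validator.check with rel fixed to >=.
def check_ge(arg_name, arg_value, value_name, value, prim_name=None):
    if not arg_value >= value:
        msg_prefix = f'For {prim_name} the' if prim_name else 'The'
        raise ValueError(f'{msg_prefix} `{arg_name}` should be >= '
                         f'{value_name}: {value}, but got {arg_value}.')

check_ge('output_num', 2, 'minimum', 1, prim_name='TopK')   # passes
# check_ge('output_num', 0, 'minimum', 1, prim_name='TopK') would raise:
# ValueError: For TopK the `output_num` should be >= minimum: 1, but got 0.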
    @staticmethod
    def check_integer(arg_name, arg_value, value, rel, prim_name):
        """Check an integer argument against a bound, rejecting bools."""
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
                             f' but got {arg_value}.')
        return arg_value
    @staticmethod
    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """Check whether an int value is in the given range."""
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int)
        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
                             f' but got {arg_value}.')
        return arg_value
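A standalone sketch (not the module's code) of the range check, assuming Rel.INC_LEFT denotes the half-open interval [lower, upper), as its use for axis validation below suggests:

# Sketch of check_int_range with Rel.INC_LEFT assumed to mean [lower, upper).
def check_int_range(arg_name, arg_value, lower, upper, prim_name):
    in_range = isinstance(arg_value, int) and lower <= arg_value < upper
    if not in_range:
        raise ValueError(f"For '{prim_name}' the `{arg_name}` should be an int "
                         f"in range [{lower}, {upper}), but got {arg_value}.")
    return arg_value

x_rank = 3
check_int_range('axis', -3, -x_rank, x_rank, 'ArgMaxWithValue')   # passes
# check_int_range('axis', 3, -x_rank, x_rank, 'ArgMaxWithValue')  # raises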
    @staticmethod
    def check_subclass(arg_name, type_, template_type, prim_name):
        """Check whether one type is a subclass of another type."""
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any(mstype.issubclass_(type_, x) for x in template_type):
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
    @staticmethod
    def check_tensor_type_same(args, valid_values, prim_name):
        """Check whether the element types of the input tensors are the same."""
        def _check_tensor_type(arg):
            arg_key, arg_val = arg
            Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
            elem_type = arg_val.element_type()
            if elem_type not in valid_values:
                raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {elem_type}.')
            return (arg_key, elem_type)

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            if arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1

        elem_types = map(_check_tensor_type, args.items())
        reduce(_check_types_same, elem_types)
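The map/reduce idiom here generalizes to any pairwise-equality fold; a standalone sketch with plain strings standing in for MindSpore element types:

from functools import reduce

def _check_same(arg1, arg2):
    # The fold keeps the first (name, type) pair as the accumulator, so every
    # later entry is compared against the first; a mismatch raises at once.
    if arg1[1] != arg2[1]:
        raise TypeError(f'`{arg2[0]}` should be same as `{arg1[0]}`,'
                        f' but got {arg1[1]} and {arg2[1]}.')
    return arg1

reduce(_check_same, {'x': 'float32', 'y': 'float32'}.items())   # passes
# reduce(_check_same, {'x': 'float32', 'y': 'int32'}.items())   # raises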
    @staticmethod
    def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
        """Check whether the types of the inputs are the same; for tensor inputs, compare their element types."""
        def _check_argument_type(arg):
            arg_key, arg_val = arg
            if isinstance(arg_val, type(mstype.tensor)):
                arg_val = arg_val.element_type()
            if arg_val not in valid_values:
                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {arg_val}.')
            return arg

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            excp_flag = False
            if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
                arg1_type = arg1_type.element_type()
                arg2_type = arg2_type.element_type()
            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
                pass
            else:
                excp_flag = True

            if excp_flag or arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1

        reduce(_check_types_same, map(_check_argument_type, args.items()))
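The excp_flag logic encodes a simple rule: two tensors compare by element type, two scalars compare directly, and a tensor mixed with a scalar always fails. A standalone distillation (hypothetical names, not MindSpore API):

def types_match(type1, type2, is_tensor1, is_tensor2):
    if is_tensor1 != is_tensor2:
        return False          # tensor mixed with scalar: always a mismatch
    return type1 == type2     # otherwise compare (element) types directly

assert types_match('float32', 'float32', True, True)        # tensor vs tensor
assert not types_match('float32', 'float32', True, False)   # tensor vs scalar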
    @staticmethod
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """Check whether a value is an instance of the given types."""
        def raise_error_msg():
            """Raise an error message when the check fails."""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
                            f'{"one of " if num_types > 1 else ""}'
                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

        # Notice: bool is a subclass of int, so `check_value_type('x', True, [int])` will fail
        # the check, while `check_value_type('x', True, [bool, int])` will pass it.
        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        raise_error_msg()
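The bool guard matters because bool is a subclass of int in Python, so isinstance alone cannot tell them apart. Assuming the class above is importable, hypothetical calls would behave like this:

assert issubclass(bool, int) and isinstance(True, int)  # why the guard exists
# Validator.check_value_type('keep_dims', True, [bool], 'ReduceSum')  # passes
# Validator.check_value_type('keep_dims', True, [int], 'ReduceSum')   # TypeError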
class ParamValidator:
    """Parameter validator. NOTICE: this class will be replaced by `class Validator`."""

    @staticmethod
    def equal(arg_name, arg_value, cond_str, cond):
...
@@ -16,13 +16,14 @@
"""broadcast"""

def _get_broadcast_shape(x_shape, y_shape, prim_name):
    """
    Do broadcasting between tensor x and tensor y.

    Args:
        x_shape (list): The shape of tensor x.
        y_shape (list): The shape of tensor y.
        prim_name (str): Primitive name.

    Returns:
        List, the broadcast shape of tensor x and tensor y.
@@ -50,7 +51,8 @@ def _get_broadcast_shape(x_shape, y_shape):
        elif x_shape[i] == y_shape[i]:
            broadcast_shape_back.append(x_shape[i])
        else:
            raise ValueError("For '{}' the x_shape {} and y_shape {} can not broadcast.".format(
                prim_name, x_shape, y_shape))
    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
    broadcast_shape = broadcast_shape_front + broadcast_shape_back
...
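Only part of _get_broadcast_shape is shown above; a hedged standalone sketch of the full NumPy-style rule it implements (compare trailing dimensions; a 1 broadcasts, equal dims match, anything else raises):

def broadcast_shape(x_shape, y_shape, prim_name='TensorAdd'):
    back = []
    for x, y in zip(reversed(x_shape), reversed(y_shape)):
        if x == 1:
            back.append(y)
        elif y == 1 or x == y:
            back.append(x)
        else:
            raise ValueError("For '{}' the x_shape {} and y_shape {} can not "
                             "broadcast.".format(prim_name, x_shape, y_shape))
    # Prepend the leading dims of the longer shape, then restore order.
    longer = x_shape if len(x_shape) >= len(y_shape) else y_shape
    return list(longer[:abs(len(x_shape) - len(y_shape))]) + back[::-1]

assert broadcast_shape([3, 1, 5], [3, 4, 5]) == [3, 4, 5]
assert broadcast_shape([5], [2, 3, 5]) == [2, 3, 5]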
@@ -28,9 +28,16 @@
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims):
    validator.check_type('keep_dims', keep_dims, [bool])
    validator.check_type('axis', axis, [int, tuple])
    if isinstance(axis, tuple):
        for index, value in enumerate(axis):
            validator.check_type('axis[%d]' % index, value, [int])
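Hypothetical calls (not in the commit) showing what the relocated helper accepts and rejects:

_check_infer_attr_reduce((0, 1), True)       # passes: tuple of ints, bool flag
_check_infer_attr_reduce(0, False)           # passes: a bare int axis
# _check_infer_attr_reduce((0, 1.5), True)   # raises: axis[1] is not an int
# _check_infer_attr_reduce((0, 1), 1)        # raises: keep_dims must be bool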
class ExpandDims(PrimitiveWithInfer):
    """
@@ -1090,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
        axis = self.axis
        x_rank = len(x_shape)
        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
        output_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
        return output_shape, output_shape

    def infer_dtype(self, x_dtype):
@@ -1136,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
        axis = self.axis
        x_rank = len(x_shape)
        validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
        output_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
        return output_shape, output_shape

    def infer_dtype(self, x_dtype):
...
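The _infer_shape_reduce signature now carries the primitive name for error reporting. For a single int axis, its shape arithmetic can be sketched as follows (an assumption about its semantics, mirroring NumPy's keepdims):

def infer_shape_reduce(x_shape, axis, keep_dims):
    axis = axis % len(x_shape)  # normalize a negative axis
    if keep_dims:
        return [1 if i == axis else dim for i, dim in enumerate(x_shape)]
    return [dim for i, dim in enumerate(x_shape) if i != axis]

assert infer_shape_reduce([3, 4, 5], -1, True) == [3, 4, 1]
assert infer_shape_reduce([3, 4, 5], 1, False) == [3, 5]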
@@ -194,6 +194,9 @@ class PrimitiveWithInfer(Primitive):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_shape)

    def prim_name(self):
        return self.__class__.__name__
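prim_name() simply reports the subclass name, which is what every validator above interpolates into its message. A stripped-down stand-in:

class _PrimitiveWithInfer:           # stand-in, not the real base class
    def prim_name(self):
        return self.__class__.__name__

class ArgMaxWithValue(_PrimitiveWithInfer):
    pass

assert ArgMaxWithValue().prim_name() == 'ArgMaxWithValue'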
    def _clone(self):
        """
        Deeply clones the primitive object.
...
@@ -23,20 +23,25 @@ from ...utils import keyword
class CheckExceptionsEC(IExectorComponent):
    """
    Check that the function raises the expected Exception and, when `error_keywords`
    is not None, that the error message contains the specified keywords.

    Examples:
        {
            'block': f,
            'exception': Exception,
            'error_keywords': ['TensorAdd', 'shape']
        }
    """
    def run_function(self, function, inputs, verification_set):
        f = function[keyword.block]
        args = inputs[keyword.desc_inputs]
        e = function.get(keyword.exception, Exception)
        error_kws = function.get(keyword.error_keywords, None)
        try:
            with pytest.raises(e) as exec_info:
                f(*args)
        except:
            raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}")
        if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
            raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
                str(exec_info.value), error_kws))
@@ -87,8 +87,9 @@ def get_function_config(function):
    init_param_with = function.get(keyword.init_param_with, None)
    split_outputs = function.get(keyword.split_outputs, True)
    exception = function.get(keyword.exception, Exception)
    error_keywords = function.get(keyword.error_keywords, None)
    return delta, max_error, input_selector, output_selector, sampling_times, \
        reduce_output, init_param_with, split_outputs, exception, error_keywords

def get_grad_checking_options(function, inputs):
    """
@@ -104,6 +105,6 @@ def get_grad_checking_options(function, inputs):
    """
    f = function[keyword.block]
    args = inputs[keyword.desc_inputs]
    delta, max_error, input_selector, output_selector, sampling_times, reduce_output, _, _, _, _ = \
        get_function_config(function)
    return f, args, delta, max_error, input_selector, output_selector, sampling_times, reduce_output
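A hypothetical call showing the widened tuple: error_keywords is now the tenth element and defaults to None, which is why get_grad_checking_options gains one more `_` placeholder:

(delta, max_error, input_selector, output_selector, sampling_times,
 reduce_output, init_param_with, split_outputs, exception,
 error_keywords) = get_function_config({})
assert exception is Exception and error_keywords is None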
@@ -54,11 +54,12 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
    block = block_config
    delta, max_error, input_selector, output_selector, \
        sampling_times, reduce_output, init_param_with, split_outputs, exception, error_keywords = get_function_config({})
    if isinstance(block_config, tuple) and isinstance(block_config[-1], dict):
        block = block_config[0]
        delta, max_error, input_selector, output_selector, \
            sampling_times, reduce_output, init_param_with, \
            split_outputs, exception, error_keywords = get_function_config(block_config[-1])

    if block:
        func_list.append({
@@ -78,7 +79,8 @@ def fill_block_config(ret, block_config, tid, group, desc_inputs, desc_bprop, ex
            keyword.const_first: const_first,
            keyword.add_fake_input: add_fake_input,
            keyword.split_outputs: split_outputs,
            keyword.exception: exception,
            keyword.error_keywords: error_keywords
        })

    if desc_inputs or desc_const:
...
@@ -73,5 +73,6 @@ keyword.const_first = "const_first"
keyword.add_fake_input = "add_fake_input"
keyword.fake_input_type = "fake_input_type"
keyword.exception = "exception"
keyword.error_keywords = "error_keywords"

sys.modules[__name__] = keyword
@@ -234,7 +234,7 @@ raise_set = [
        'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
    ('ReduceSum_Error', {
        'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}),
        'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
]
...
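The expected exception for the ReduceSum case changes because keep_dims is now validated as a value type: passing the int 1 where a bool is required is a type error, not a value error. Illustratively (exact wording assumed):

# P.ReduceSum(keep_dims=1)
# TypeError: For 'ReduceSum' the type of `keep_dims` should be bool, but got int.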