# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Check parameters."""
import re
import inspect
import math
from enum import Enum
from functools import reduce, wraps
from itertools import repeat
from collections.abc import Iterable

import numpy as np
from mindspore import log as logger
from .common import dtype as mstype


# Named string regular expression
_name_re = r"^\w+[0-9a-zA-Z\_\.]*$"


class Rel(Enum):
    """Numerical relationship between variables, logical relationship enumeration definition of range."""
    # scalar compare
    EQ = 1  # ==
    NE = 2  # !=
    LT = 3  # <
    LE = 4  # <=
    GT = 5  # >
    GE = 6  # >=
    # scalar range check
    INC_NEITHER = 7  # (), include neither
    INC_LEFT = 8  # [), include left
    INC_RIGHT = 9  # (], include right
    INC_BOTH = 10  # [], include both
    # collection in, not in
    IN = 11
    NOT_IN = 12

    @staticmethod
    def get_strs(rel):
        """Return the message template registered for `rel` in `rel_strs`, or "" if there is none."""
        return rel_strs.get(rel, "")

    @staticmethod
    def get_fns(rel):
        """
        Return the predicate registered for `rel` in `rel_fns`.

        An unknown `rel` yields a predicate that always returns False, so every check using it fails.
        """
        return rel_fns.get(rel, lambda *args: False)


# Predicates implementing each relation. Compare relations take (x, y); range relations take
# (x, lower, upper); collection relations take (x, collection).
rel_fns = {
    # scalar compare
    Rel.EQ: lambda x, y: x == y,
    Rel.NE: lambda x, y: x != y,
    Rel.LT: lambda x, y: x < y,
    Rel.LE: lambda x, y: x <= y,
    Rel.GT: lambda x, y: x > y,
    Rel.GE: lambda x, y: x >= y,
    # scalar range check
    Rel.INC_NEITHER: lambda x, lower, upper: (lower < x < upper),
    Rel.INC_LEFT: lambda x, lower, upper: (lower <= x < upper),
    Rel.INC_RIGHT: lambda x, lower, upper: (lower < x <= upper),
    Rel.INC_BOTH: lambda x, lower, upper: (lower <= x <= upper),
    # collection in, not in
    Rel.IN: lambda x, y: x in y,
    Rel.NOT_IN: lambda x, y: x not in y,
}

# `str.format` templates used to describe each relation in error messages.
rel_strs = {
    # scalar compare
    Rel.EQ: "equal to {}",
    Rel.NE: "not equal to {}",
    Rel.LT: "less than {}",
    Rel.LE: "less or equal to {}",
    Rel.GT: "greater than {}",
    Rel.GE: "greater or equal to {}",
    # scalar range check
    Rel.INC_NEITHER: "({}, {})",
    Rel.INC_LEFT: "[{}, {})",
    Rel.INC_RIGHT: "({}, {}]",
    Rel.INC_BOTH: "[{}, {}]",
    # collection in, not in
    Rel.IN: "in {}",
    Rel.NOT_IN: "not in {}",
}


class Validator:
    """validator for checking input parameters"""

    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
        """
        Method for judging relation between two int values or list/tuple made up of ints.

        This method is not suitable for judging relation between floats, since it does not consider float error.

        Raises:
            excp_cls: If `rel` does not hold between `arg_value` and `value`.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
            raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')

    @staticmethod
    def check_integer(arg_name, arg_value, value, rel, prim_name):
        """
        Check that `arg_value` is an int (bool excluded) satisfying `rel` against `value`.

        Returns:
            The unchanged `arg_value`.

        Raises:
            TypeError: If `arg_value` is not an int, or is a bool.
            ValueError: If the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        excp_cls = TypeError if type_mismatch else ValueError
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
            raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`'
                           f' with type `{type(arg_value).__name__}`.')
        return arg_value

    @staticmethod
    def check_number(arg_name, arg_value, value, rel, prim_name):
        """
        Check that `arg_value` satisfies `rel` against `value`; no type check is performed.

        Returns:
            The unchanged `arg_value`.

        Raises:
            ValueError: If the relation does not hold.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """
        Method for checking whether an int value is in some range.

        Returns:
            The unchanged `arg_value`.

        Raises:
            TypeError: If `arg_value` is not an int, or is a bool.
            ValueError: If `arg_value` is outside the range described by `rel`.
        """
        rel_fn = Rel.get_fns(rel)
        # bool is a subclass of int; reject it explicitly, consistent with `check_integer`.
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        excp_cls = TypeError if type_mismatch else ValueError
        if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},'
                           f' but got `{arg_value}` with type `{type(arg_value).__name__}`.')
        return arg_value

    @staticmethod
    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """
        Method for checking whether a numeric value is in some range; no type check is performed.

        Returns:
            The unchanged `arg_value`.

        Raises:
            ValueError: If `arg_value` is outside the range described by `rel`.
        """
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
        return arg_value

    @staticmethod
    def check_string(arg_name, arg_value, valid_values, prim_name):
        """
        Checks whether a string is in some value list.

        Returns:
            The unchanged `arg_value`.

        Raises:
            ValueError: If `arg_value` is not a str or is not contained in `valid_values`.
        """
        if isinstance(arg_value, str) and arg_value in valid_values:
            return arg_value
        if len(valid_values) == 1:
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
                             f' but got {arg_value}.')
        raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
                         f' but got {arg_value}.')

    @staticmethod
    def check_pad_value_by_mode(pad_mode, padding, prim_name):
        """
        Validates value of padding according to pad_mode.

        Returns:
            The unchanged `padding`.

        Raises:
            ValueError: If `pad_mode` is not 'pad' and `padding` is not 0.
        """
        if pad_mode != 'pad' and padding != 0:
            raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
        return padding

    @staticmethod
    def check_float_positive(arg_name, arg_value, prim_name):
        """
        Check that `arg_value` is a float greater than 0.

        Returns:
            The unchanged `arg_value`.

        Raises:
            ValueError: If `arg_value` is a float that is not greater than 0.
            TypeError: If `arg_value` is not a float.
        """
        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
        if isinstance(arg_value, float):
            if arg_value > 0:
                return arg_value
            raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
        raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")

    @staticmethod
    def check_subclass(arg_name, type_, template_type, prim_name):
        """
        Checks whether some type is subclass of another type.

        `template_type` may be a single type or an iterable of types; the check passes when
        `mstype.issubclass_` accepts `type_` for at least one of them.

        Raises:
            TypeError: If `type_` matches none of the template types.
        """
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any(mstype.issubclass_(type_, x) for x in template_type):
            # Prefix the container kind ("tuple"/"list") when a collection of types was passed.
            type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')

    @staticmethod
    def check_const_input(arg_name, arg_value, prim_name):
        """
        Checks valid value.

        Raises:
            ValueError: If `arg_value` is None.
        """
        if arg_value is None:
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')

    @staticmethod
    def check_type_same(args, valid_values, prim_name):
        """
        Checks whether the types of inputs are the same.

        Args:
            args (dict): Maps argument name to its type.
            valid_values: Collection of accepted types.
            prim_name: Name used in the error message.

        Raises:
            TypeError: If any type is not in `valid_values`, or if two types differ.
                An empty `args` also ends in TypeError, raised by `reduce` on an empty sequence.
        """
        def _check_tensor_type(arg):
            arg_key, arg_val = arg
            elem_type = arg_val
            if elem_type not in valid_values:
                types_info = '[' + ', '.join(str(t) for t in valid_values) + ']'
                raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
                                f' but got {elem_type}.')
            return (arg_key, elem_type)

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            if arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.')
            # Keep the first argument as the accumulator so every later one is compared against it.
            return arg1

        elem_types = map(_check_tensor_type, args.items())
        reduce(_check_types_same, elem_types)

    @staticmethod
    def check_tensor_type_same(args, valid_values, prim_name):
        """
        Checks whether the element types of input tensors are the same.

        Each entry of `valid_values` is wrapped with `mstype.tensor_type` before delegating to
        `check_type_same`.
        """
        tensor_types = [mstype.tensor_type(t) for t in valid_values]
        Validator.check_type_same(args, tensor_types, prim_name)

    @staticmethod
    def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
        """
        Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.

        If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.

        Raises:
            TypeError: If a type (or tensor element type) is not in `valid_values`, if two types differ,
                or if tensor and non-tensor types are mixed while `allow_mix` is False.
        """
        def _check_argument_type(arg):
            arg_key, arg_val = arg
            if isinstance(arg_val, type(mstype.tensor)):
                arg_val = arg_val.element_type()
            if arg_val not in valid_values:
                raise TypeError(f'For \'{prim_name}\' the `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {arg_val}.')
            # Return the original pair: the pairwise comparison still needs to know whether it was a tensor.
            return arg

        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            except_flag = False
            if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
                # Both tensors: compare element types.
                arg1_type = arg1_type.element_type()
                arg2_type = arg2_type.element_type()
            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
                # Neither is a tensor: compare the types directly.
                pass
            elif allow_mix:
                # Exactly one tensor: unwrap it so it can be compared with the scalar type.
                arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
                arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
            else:
                except_flag = True

            if except_flag or arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1

        reduce(_check_types_same, map(_check_argument_type, args.items()))

    @staticmethod
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """
        Checks whether a value is instance of some types.

        Returns:
            The unchanged `arg_value`.

        Raises:
            TypeError: If `arg_value` is not an instance of `valid_types`, or is a bool while bool is
                not listed in `valid_types`.
        """
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

        def raise_error_msg():
            """func for raising error message when check failed"""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

        # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
        #         `check_value_type('x', True, [bool, int])` will check pass
        if isinstance(arg_value, bool) and bool not in tuple(valid_types):
            raise_error_msg()
        if isinstance(arg_value, tuple(valid_types)):
            return arg_value
        raise_error_msg()

    @staticmethod
    def check_type_name(arg_name, arg_type, valid_types, prim_name):
        """
        Checks whether a type in some specified types.

        Returns:
            The unchanged `arg_type`.

        Raises:
            TypeError: If `arg_type` is not contained in `valid_types`.
        """
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)

        def get_typename(t):
            return t.__name__ if hasattr(t, '__name__') else str(t)

        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
        msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
        if len(valid_types) == 1:
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
                            f' but got {get_typename(arg_type)}.')
        raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
                        f' but got {get_typename(arg_type)}.')

    @staticmethod
    def check_float_legal_value(arg_name, arg_value, prim_name):
        """
        Checks whether a legal value of float type.

        Returns:
            The unchanged `arg_value`.

        Raises:
            ValueError: If `arg_value` is a float that is inf or nan.
            TypeError: If `arg_value` is not a float.
        """
        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
        if isinstance(arg_value, float):
            if math.isinf(arg_value) or math.isnan(arg_value):
                raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.")
            return arg_value
        raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")


def check_int(input_param):
    """Return `input_param` if it is an int and not a bool, otherwise raise TypeError."""
    if isinstance(input_param, int) and not isinstance(input_param, bool):
        return input_param
    raise TypeError("Input type must be int!")


def check_int_positive(input_param):
    """
    Return `input_param` if it is an int (bool excluded) greater than 0.

    Raises:
        TypeError: If `input_param` is a bool or not an int.
        ValueError: If `input_param` is an int that is not greater than 0.
    """
    if isinstance(input_param, bool):
        raise TypeError("Input type must be int cannot be bool!")
    if isinstance(input_param, int):
        if input_param > 0:
            return input_param
        raise ValueError("The input_param must be positive, but got input_param {}.".format(input_param))
    raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))


def check_int_non_negative(input_param):
    """
    Return `input_param` if it is an int (bool excluded) greater than or equal to 0.

    Raises:
        TypeError: If `input_param` is a bool or not an int.
        ValueError: If `input_param` is a negative int.
    """
    if isinstance(input_param, bool):
        raise TypeError("Input type must be int cannot be bool!")
    if isinstance(input_param, int):
        if input_param >= 0:
            return input_param
        raise ValueError("The input_param must be non_negative, but got input_param {}.".format(input_param))
    raise TypeError("Input type must be int cannot be {}!".format(type(input_param)))


def check_int_zero_one(input_param):
    """Judge whether it is 0 or 1; raise ValueError otherwise."""
    if input_param in (0, 1):
        return input_param
    raise ValueError("The data must be 0 or 1.")


def check_bool(input_param):
    """Return `input_param` if it is a bool, otherwise raise TypeError."""
    if isinstance(input_param, bool):
        return input_param
    raise TypeError("Input type must be bool!")


def check_input_format(input_param):
    """Judge input format: only "NCHW" is accepted, anything else raises ValueError."""
    if input_param == "NCHW":
        return input_param
    raise ValueError("The data format must be NCHW.")


def check_padding(padding):
    """Return `padding` if it is greater than or equal to 0, otherwise raise ValueError."""
    if padding >= 0:
        return padding
    raise ValueError("The padding must be at least 0, but got padding {}.".format(padding))


def check_padmode(mode):
    """Return `mode` if it is one of "same", "valid" or "pad", otherwise raise ValueError."""
    if mode in ("same", "valid", "pad"):
        return mode
    raise ValueError("The pad mode must be same or valid or pad, but got mode {}.".format(mode))


def check_tensor_supported_type(dtype):
    """Return `dtype` if it is `mstype.int32` or `mstype.float32`, otherwise raise ValueError."""
    if dtype in (mstype.int32, mstype.float32):
        return dtype
    raise ValueError("The dtype must be mstype.int32 or mstype.float32, but got mstype {}.".format(dtype))


def _expand_tuple(n_dimensions):
    """
    Build a converter that expands a number to a tuple of length `n_dimensions`.

    The returned function repeats an int `n_dimensions` times, and passes through a tuple of ints
    that already has exactly `n_dimensions` elements. Anything else raises TypeError.
    """

    def convert(m):
        if not isinstance(m, tuple):
            if isinstance(m, int):
                return tuple(repeat(m, n_dimensions))
            raise TypeError("Input type must be int or tuple.")

        # Compare by value: identity (`is`) on ints only works by accident of small-int caching.
        if len(m) != n_dimensions:
            raise TypeError("Input dimension is incorrect.")

        for i in m:
            if not isinstance(i, int):
                raise TypeError("Incorrect type inside of a tuple!")
        return m

    return convert


def check_input_data(*data, data_class):
    """
    Input data check.

    Lists and tuples are checked recursively; every other item must be an instance of `data_class`
    and report a non-zero `size()`.

    Raises:
        ValueError: If an item is not a `data_class` instance, or its `size()` is 0.
    """
    for item in data:
        if isinstance(item, (list, tuple)):
            for v in item:
                check_input_data(v, data_class=data_class)
        else:
            if not isinstance(item, data_class):
                raise ValueError(f'Please provide as model inputs'
                                 f' either a single'
                                 f' or a list of {data_class.__name__},'
                                 f' but got part data type is {str(type(item))}.')
            if item.size() == 0:
                msg = "Please provide non-empty data."
                logger.error(msg)
                raise ValueError(msg)


def check_output_data(data):
    """Output data check: raise RuntimeError when `data` is falsy."""
    if not data:
        raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')


once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
valid_data_types = (int, float, np.int8, np.int16, np.int32, np.int64,
                    np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
                    np.float32, np.float64, bool, np.bool_)


def check_type(arg_name, arg_value, valid_types):
    """
    Check value type.

    Returns:
        str, the type name of the (possibly converted) value.

    Raises:
        TypeError: If the value is neither an instance of `valid_types` nor a list/tuple/ndarray whose
            numpy dtype is in `valid_data_types`, or if it is a bool while bool is not in `valid_types`.
    """
    # if input type is Tensor ,get element type
    if isinstance(arg_value, type(mstype.tensor)):
        arg_value = arg_value.element_type()

    # Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
    #         `check_type('x', True, [bool, int])` will check pass.
    # This guard must run before the isinstance shortcut below, otherwise a bool is accepted as an int.
    bool_rejected = isinstance(arg_value, bool) and bool not in tuple(valid_types)

    # First, check if arg_value has argvalid_types
    if not bool_rejected and isinstance(arg_value, tuple(valid_types)):
        return type(arg_value).__name__

    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
    if isinstance(arg_value, (list, tuple)):
        arg_value = np.array(arg_value)

    # Thirdly, check the data type by numpy's dtype api
    valid = False
    if isinstance(arg_value, np.ndarray):
        valid = arg_value.dtype in valid_data_types

    if not valid:
        type_names = [t.__name__ for t in valid_types]
        if len(valid_types) == 1:
            raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
                            f' but got {type(arg_value).__name__}.')
        raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
                        f' but got {type(arg_value).__name__}.')

    return type(arg_value).__name__


def check_typename(arg_name, arg_type, valid_types):
    """
    Check type name.

    A tensor type is replaced by its element type first. The check passes when the type is contained
    in `valid_types` or is an instance of one of them.

    Returns:
        The (possibly unwrapped) `arg_type`.

    Raises:
        TypeError: If neither condition holds.
    """

    def get_typename(t):
        return t.__name__ if hasattr(t, '__name__') else str(t)

    if isinstance(arg_type, type(mstype.tensor)):
        arg_type = arg_type.element_type()

    if arg_type in valid_types:
        return arg_type
    if isinstance(arg_type, tuple(valid_types)):
        return arg_type
    type_names = [get_typename(t) for t in valid_types]
    if len(valid_types) == 1:
        raise TypeError(f'The type of `{arg_name}` should be {type_names[0]},'
                        f' but got {get_typename(arg_type)}.')
    raise TypeError(f'The type of `{arg_name}` should be one of {type_names},'
                    f' but got {get_typename(arg_type)}.')


def check_shape(arg_name, arg_value):
    """
    Check shape.

    Raises:
        TypeError: If `arg_value` is not a tuple.
        ValueError: If the shape is empty, is not 1-dimensional, or contains an element that is not a
            positive integer.
    """
    # First, check if shape is a tuple
    if not isinstance(arg_value, tuple):
        raise TypeError(f'The type of `{arg_name}` should be one of {tuple.__name__},'
                        f' but got {type(arg_value).__name__}.')

    # Second, wrap arg_value with numpy array so that it can be checked through numpy api
    arg_value = np.array(arg_value)

    # shape can not be ()
    if arg_value.size == 0:
        raise ValueError('Shape can not be empty.')

    # shape's dimension should be 1
    if arg_value.ndim != 1:
        raise ValueError('Shape of tensor should be 1-dim vector, but got {}-dim.'.format(arg_value.ndim))

    # Thirdly, check each element's type of the shape
    valid_types = (int, np.int8, np.int16, np.int32, np.int64,
                   np.uint8, np.uint16, np.uint32, np.uint64)
    for dim_size in arg_value:
        if not isinstance(dim_size, valid_types) or dim_size <= 0:
            raise ValueError('Every dimension size of the tensor shape should be a positive integer,'
                             ' but got {}.'.format(dim_size))


def _check_str_by_regular(target, reg=None, flag=re.ASCII):
    """
    Check that `target` matches the regular expression `reg` (default `_name_re`) under `flag`.

    Returns:
        bool, always True when the match succeeds.

    Raises:
        ValueError: If `target` does not match.
    """
    if reg is None:
        reg = _name_re
    if re.match(reg, target, flag) is None:
        raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
    return True


def args_type_check(*type_args, **type_kwargs):
    """
    Check whether input data type is correct.

    Decorator factory: the given types are bound to the decorated function's signature, and on every
    call each argument that has a bound type and is not None must be an instance of that type.

    Raises:
        TypeError: From the wrapper, if an argument is not an instance of its bound type.
    """

    def type_check(func):
        sig = inspect.signature(func)
        bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
        # Types given for a parameter named `kwargs` arrive nested under the "kwargs" key. Unwrap them
        # once here instead of rebinding the closure variable on every call.
        if "kwargs" in bound_types:
            bound_types = bound_types["kwargs"]

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound_values = sig.bind(*args, **kwargs)
            argument_dict = bound_values.arguments
            if "kwargs" in argument_dict:
                argument_dict = argument_dict["kwargs"]
            for name, value in argument_dict.items():
                if name in bound_types:
                    if value is not None and not isinstance(value, bound_types[name]):
                        raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
            return func(*args, **kwargs)

        return wrapper

    return type_check