diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index cb3dbc0d501c5ad13cd3fdf2644aab31078e306d..e9a928461f5bafa15f6a59058ec4455af554720e 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -17,7 +17,7 @@ import re from enum import Enum from functools import reduce from itertools import repeat -from collections import Iterable +from collections.abc import Iterable import numpy as np from mindspore import log as logger @@ -98,7 +98,7 @@ class Validator: """validator for checking input parameters""" @staticmethod - def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): + def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): """ Method for judging relation between two int values or list/tuple made up of ints. @@ -108,8 +108,8 @@ class Validator: rel_fn = Rel.get_fns(rel) if not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') - msg_prefix = f'For {prim_name} the' if prim_name else "The" - raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') @staticmethod def check_integer(arg_name, arg_value, value, rel, prim_name): @@ -118,8 +118,17 @@ class Validator: type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) - raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' - f' but got {arg_value}.') + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') + return arg_value + + @staticmethod + def check_number(arg_name, arg_value, value, rel, prim_name): + """Integer value judgment.""" + rel_fn = Rel.get_fns(rel) + if not rel_fn(arg_value, value): + rel_str = Rel.get_strs(rel).format(value) + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') return arg_value @staticmethod @@ -133,9 +142,46 @@ class Validator: f' but got {arg_value}.') return arg_value + @staticmethod + def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): + """Method for checking whether a numeric value is in some range.""" + rel_fn = Rel.get_fns(rel) + if not rel_fn(arg_value, lower_limit, upper_limit): + rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') + return arg_value + + @staticmethod + def check_string(arg_name, arg_value, valid_values, prim_name): + """Checks whether a string is in some value list""" + if isinstance(arg_value, str) and arg_value in valid_values: + return arg_value + if len(valid_values) == 1: + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},' + f' but got {arg_value}.') + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},' + f' but got {arg_value}.') + + @staticmethod + def check_pad_value_by_mode(pad_mode, padding, prim_name): + """Validates value of padding according to pad_mode""" + if pad_mode != 'pad' and padding != 0: + raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") + return padding + + @staticmethod + def check_float_positive(arg_name, arg_value, prim_name): + """Float type judgment.""" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + if isinstance(arg_value, float): + if arg_value > 0: + return arg_value + raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") + raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") + @staticmethod def check_subclass(arg_name, type_, template_type, prim_name): - """Check whether some type is sublcass of another type""" + """Checks whether some type is sublcass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -143,16 +189,44 @@ class Validator: raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') + @staticmethod + def check_const_input(arg_name, arg_value, prim_name): + """Check valid value.""" + if arg_value is None: + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') + + @staticmethod + def check_scalar_type_same(args, valid_values, prim_name): + """check whether the types of inputs are the same.""" + def _check_tensor_type(arg): + arg_key, arg_val = arg + elem_type = arg_val + if not elem_type in valid_values: + raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},' + f' but `{arg_key}` is {elem_type}.') + return (arg_key, elem_type) + + def _check_types_same(arg1, arg2): + arg1_name, arg1_type = arg1 + arg2_name, arg2_type = arg2 + if arg1_type != arg2_type: + raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' + f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') + return arg1 + + elem_types = map(_check_tensor_type, args.items()) + reduce(_check_types_same, elem_types) + @staticmethod def check_tensor_type_same(args, valid_values, prim_name): - """check whether the element types of input tensors are the same.""" + """Checks whether the element types of input tensors are the same.""" def _check_tensor_type(arg): arg_key, arg_val = arg Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) elem_type = arg_val.element_type() if not elem_type in valid_values: raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' - f' but `{arg_key}` is {elem_type}.') + f' but element type of `{arg_key}` is {elem_type}.') return (arg_key, elem_type) def _check_types_same(arg1, arg2): @@ -168,8 +242,13 @@ class Validator: @staticmethod - def check_scalar_or_tensor_type_same(args, valid_values, prim_name): - """check whether the types of inputs are the same. if the input args are tensors, check their element types""" + def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): + """ + Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. + + If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. + """ + def _check_argument_type(arg): arg_key, arg_val = arg if isinstance(arg_val, type(mstype.tensor)): @@ -188,6 +267,9 @@ class Validator: arg2_type = arg2_type.element_type() elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): pass + elif allow_mix: + arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type + arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type else: excp_flag = True @@ -199,13 +281,14 @@ class Validator: @staticmethod def check_value_type(arg_name, arg_value, valid_types, prim_name): - """Check whether a values is instance of some types.""" + """Checks whether a value is instance of some types.""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) def raise_error_msg(): """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' - f'{"one of " if num_types > 1 else ""}' + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and @@ -216,6 +299,23 @@ class Validator: return arg_value raise_error_msg() + @staticmethod + def check_type_name(arg_name, arg_type, valid_types, prim_name): + """Checks whether a type in some specified types""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def get_typename(t): + return t.__name__ if hasattr(t, '__name__') else str(t) + + if arg_type in valid_types: + return arg_type + type_names = [get_typename(t) for t in valid_types] + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + if len(valid_types) == 1: + raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' + f' but got {get_typename(arg_type)}.') + raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' + f' but got {get_typename(arg_type)}.') + class ParamValidator: """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 5507d12af89eb710d7f2841540a45217d8b77539..a694489f5a430605ce1e9739091d7b5801a86fa5 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -103,6 +103,10 @@ class Cell: def parameter_layout_dict(self): return self._parameter_layout_dict + @property + def cls_name(self): + return self.__class__.__name__ + @parameter_layout_dict.setter def parameter_layout_dict(self, value): if not isinstance(value, dict): diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index cf25f1f50ee67fe5c6d03a0bb9275df07e5f478c..0c5a160380e637b5edd05fa16c8bfe353878101b 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -15,7 +15,7 @@ """dynamic learning rate""" import math -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel @@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates): >>> lr = piecewise_constant_lr(milestone, learning_rates) [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] """ - validator.check_type('milestone', milestone, (tuple, list)) - validator.check_type('learning_rates', learning_rates, (tuple, list)) + validator.check_value_type('milestone', milestone, (tuple, list), None) + validator.check_value_type('learning_rates', learning_rates, (tuple, list), None) if len(milestone) != len(learning_rates): raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.') lr = [] last_item = 0 for i, item in enumerate(milestone): - validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT) - validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float]) + validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) + validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None) if item < last_item: raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') lr += [learning_rates[i]] * (item - last_item) @@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates): def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) - validator.check_float_positive('learning_rate', learning_rate) - validator.check_float_positive('decay_rate', decay_rate) - validator.check_type('is_stair', is_stair, [bool]) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_positive('decay_rate', decay_rate, None) + validator.check_value_type('is_stair', is_stair, [bool], None) def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): @@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ - validator.check_float_positive('min_lr', min_lr) - validator.check_float_positive('max_lr', max_lr) - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) + validator.check_float_positive('min_lr', min_lr, None) + validator.check_float_positive('max_lr', max_lr, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) delta = 0.5 * (max_lr - min_lr) lr = [] @@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] """ - validator.check_float_positive('learning_rate', learning_rate) - validator.check_float_positive('end_learning_rate', end_learning_rate) - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) - validator.check_type('power', power, [float]) - validator.check_type('update_decay_epoch', update_decay_epoch, [bool]) + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_positive('end_learning_rate', end_learning_rate, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_value_type('power', power, [float], None) + validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) function = lambda x, y: (x, min(x, y)) if update_decay_epoch: diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 5ac52acac7195ee2f3d0ea5704ae7b9c9b717d1f..2449eea9b4a5dfefd55a58809bc577461ff87d62 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from ..cell import Cell from .activation import get_activation -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator class Dropout(Cell): @@ -73,7 +73,7 @@ class Dropout(Cell): super(Dropout, self).__init__() if keep_prob <= 0 or keep_prob > 1: raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) - validator.check_subclass("dtype", dtype, mstype.number_type) + validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) self.keep_prob = Tensor(keep_prob) self.seed0 = seed0 self.seed1 = seed1 @@ -421,7 +421,7 @@ class Pad(Cell): super(Pad, self).__init__() self.mode = mode self.paddings = paddings - validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"]) + validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name) if not isinstance(paddings, tuple): raise TypeError('Paddings must be tuple type.') for item in paddings: diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index dfa8e664692947f4c385172175eeda13f4e949b2..24b94f2f3ca8f81cd3cafccf96305970fe5a4b40 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -19,7 +19,7 @@ from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from ..cell import Cell -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator class Embedding(Cell): @@ -59,7 +59,7 @@ class Embedding(Cell): """ def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): super(Embedding, self).__init__() - validator.check_subclass("dtype", dtype, mstype.number_type) + validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) self.vocab_size = vocab_size self.embedding_size = embedding_size self.use_one_hot = use_one_hot diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 72c4c6d8e2d7bacf0a0fb6f7641b7db621aaeaab..b46ac4cd6ec91c48cd4cd77570699dd61ac244ed 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops.primitive import constexpr -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from ..cell import Cell @@ -134,15 +134,15 @@ class SSIM(Cell): """ def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): super(SSIM, self).__init__() - validator.check_type('max_val', max_val, [int, float]) - validator.check('max_val', max_val, '', 0.0, Rel.GT) + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val - self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE) - self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma) - validator.check_type('k1', k1, [float]) - self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_type('k2', k2, [float]) - self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER) + self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) + self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) + validator.check_value_type('k1', k1, [float], self.cls_name) + self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) + validator.check_value_type('k2', k2, [float], self.cls_name) + self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) def construct(self, img1, img2): @@ -231,8 +231,8 @@ class PSNR(Cell): """ def __init__(self, max_val=1.0): super(PSNR, self).__init__() - validator.check_type('max_val', max_val, [int, float]) - validator.check('max_val', max_val, '', 0.0, Rel.GT) + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val def construct(self, img1, img2): diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index cef926d365e3092d6fe05afc55ce43b75caff664..84c156a1c2889db28ae2e5b10cf22742033d5841 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -17,7 +17,7 @@ from mindspore.ops import operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator class LSTM(Cell): @@ -114,7 +114,7 @@ class LSTM(Cell): self.hidden_size = hidden_size self.num_layers = num_layers self.has_bias = has_bias - self.batch_first = validator.check_type("batch_first", batch_first, [bool]) + self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) self.dropout = float(dropout) self.bidirectional = bidirectional diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 746b6d240f9685bc4b5357c7d8c8e1b1bed3625d..53d97807cff0b4722a397eee5c5a852ec0d9f7ce 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -14,8 +14,7 @@ # ============================================================================ """pooling""" from mindspore.ops import operations as P -from mindspore._checkparam import ParamValidator as validator -from mindspore._checkparam import Rel +from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell @@ -24,35 +23,27 @@ class _PoolNd(Cell): """N-D AvgPool""" def __init__(self, kernel_size, stride, pad_mode): - name = self.__class__.__name__ super(_PoolNd, self).__init__() - validator.check_type('kernel_size', kernel_size, [int, tuple]) - validator.check_type('stride', stride, [int, tuple]) - self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) - - if isinstance(kernel_size, int): - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) - else: - if (len(kernel_size) != 2 or - (not isinstance(kernel_size[0], int)) or - (not isinstance(kernel_size[1], int)) or - kernel_size[0] <= 0 or - kernel_size[1] <= 0): - raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' - f'a tuple of two positive int numbers, but got {kernel_size}') - self.kernel_size = kernel_size - - if isinstance(stride, int): - validator.check_integer("stride", stride, 1, Rel.GE) - else: - if (len(stride) != 2 or - (not isinstance(stride[0], int)) or - (not isinstance(stride[1], int)) or - stride[0] <= 0 or - stride[1] <= 0): - raise ValueError(f'The stride passed to cell {name} should be an positive int number or' - f'a tuple of two positive int numbers, but got {stride}') - self.stride = stride + self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) + + def _check_int_or_tuple(arg_name, arg_value): + validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) + error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \ + f'a tuple of two positive int numbers, but got {arg_value}' + if isinstance(arg_value, int): + if arg_value <= 0: + raise ValueError(error_msg) + elif len(arg_value) == 2: + for item in arg_value: + if isinstance(item, int) and item > 0: + continue + raise ValueError(error_msg) + else: + raise ValueError(error_msg) + return arg_value + + self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size) + self.stride = _check_int_or_tuple('stride', stride) def construct(self, *inputs): pass diff --git a/mindspore/nn/metrics/fbeta.py b/mindspore/nn/metrics/fbeta.py index 68df4318b04246e3a4bc7b18987b49800fa2c52b..3ae5c44bc2fbafea507a792d52ee25935af6df4b 100755 --- a/mindspore/nn/metrics/fbeta.py +++ b/mindspore/nn/metrics/fbeta.py @@ -15,7 +15,7 @@ """Fbeta.""" import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .metric import Metric @@ -104,7 +104,7 @@ class Fbeta(Metric): Returns: Float, computed result. """ - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') diff --git a/mindspore/nn/metrics/precision.py b/mindspore/nn/metrics/precision.py index ad7b6c576fc2681f0d0b0f66c6beed8599522b6d..633b9f8e2cdfd67df2d2604347376cdfd373f55b 100644 --- a/mindspore/nn/metrics/precision.py +++ b/mindspore/nn/metrics/precision.py @@ -17,7 +17,7 @@ import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .evaluation import EvaluationBase @@ -136,7 +136,7 @@ class Precision(EvaluationBase): if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) result = self._true_positives / (self._positives + self.eps) if average: diff --git a/mindspore/nn/metrics/recall.py b/mindspore/nn/metrics/recall.py index 45ebf0d7db7f1db215cedeba9a9201ea6042312c..da06321aa35b6c20f7a7ce15c98fa79e2fbb8ad3 100644 --- a/mindspore/nn/metrics/recall.py +++ b/mindspore/nn/metrics/recall.py @@ -17,7 +17,7 @@ import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .evaluation import EvaluationBase @@ -136,7 +136,7 @@ class Recall(EvaluationBase): if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) result = self._true_positives / (self._actual_positives + self.eps) if average: diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index eb4e33751f51667c3915678c4c5d3e807eb8d25e..65f8ec678be89b59c7717ab56f8cb79c04273aae 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -22,7 +22,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer @@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad return next_v -def _check_param_value(beta1, beta2, eps, weight_decay): +def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): """Check the type of inputs.""" - validator.check_type("beta1", beta1, [float]) - validator.check_type("beta2", beta2, [float]) - validator.check_type("eps", eps, [float]) - validator.check_type("weight_dacay", weight_decay, [float]) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", @@ -168,11 +168,11 @@ class Adam(Optimizer): use_nesterov=False, weight_decay=0.0, loss_scale=1.0, decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) - _check_param_value(beta1, beta2, eps, weight_decay) - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("use_nesterov", use_nesterov, [bool]) - validator.check_type("loss_scale", loss_scale, [float]) - validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) + validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) + validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) self.beta1 = Tensor(beta1, mstype.float32) self.beta2 = Tensor(beta2, mstype.float32) @@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer): """ def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): super(AdamWeightDecay, self).__init__(learning_rate, params) - _check_param_value(beta1, beta2, eps, weight_decay) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) @@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer): eps=1e-6, weight_decay=0.0): super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) - _check_param_value(beta1, beta2, eps, weight_decay) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations self.global_step = Parameter(initializer(0, [1]), name="global_step") diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index ee8fc9355fd01cda872d656d8e4e40fa8435b8ae..d08dd6cf4c93f70c8c46ff42bf6179e296669873 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common import Tensor import mindspore.common.dtype as mstype -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer, apply_decay, grad_scale @@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) return success -def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0): - validator.check_type("initial_accum", initial_accum, [float]) - validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE) +def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0, + prim_name=None): + validator.check_value_type("initial_accum", initial_accum, [float], prim_name) + validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name) - validator.check_type("learning_rate", learning_rate, [float]) - validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT) + validator.check_value_type("learning_rate", learning_rate, [float], prim_name) + validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name) - validator.check_type("lr_power", lr_power, [float]) - validator.check("lr_power", lr_power, "", 0.0, Rel.LE) + validator.check_value_type("lr_power", lr_power, [float], prim_name) + validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name) - validator.check_type("l1", l1, [float]) - validator.check("l1", l1, "", 0.0, Rel.GE) + validator.check_value_type("l1", l1, [float], prim_name) + validator.check_number("l1", l1, 0.0, Rel.GE, prim_name) - validator.check_type("l2", l2, [float]) - validator.check("l2", l2, "", 0.0, Rel.GE) + validator.check_value_type("l2", l2, [float], prim_name) + validator.check_number("l2", l2, 0.0, Rel.GE, prim_name) - validator.check_type("use_locking", use_locking, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], prim_name) - validator.check_type("loss_scale", loss_scale, [float]) - validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE) + validator.check_value_type("loss_scale", loss_scale, [float], prim_name) + validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name) - validator.check_type("weight_decay", weight_decay, [float]) - validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE) + validator.check_value_type("weight_decay", weight_decay, [float], prim_name) + validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name) class FTRL(Optimizer): @@ -94,7 +95,8 @@ class FTRL(Optimizer): use_locking=False, loss_scale=1.0, weight_decay=0.0): super(FTRL, self).__init__(learning_rate, params) - _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay) + _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay, + self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.linear = self.parameters.clone(prefix="linear", init='zeros') self.l1 = l1 diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index e74d6fc6a811573dc5992dfaf8d8dffcee0da26f..afcbf8cda47c3dfd7ced96f6a97c9f706383870a 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -21,7 +21,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer from .. import layer @@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para def _check_param_value(decay_steps, warmup_steps, start_learning_rate, - end_learning_rate, power, beta1, beta2, eps, weight_decay): + end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): """Check the type of inputs.""" - validator.check_type("decay_steps", decay_steps, [int]) - validator.check_type("warmup_steps", warmup_steps, [int]) - validator.check_type("start_learning_rate", start_learning_rate, [float]) - validator.check_type("end_learning_rate", end_learning_rate, [float]) - validator.check_type("power", power, [float]) - validator.check_type("beta1", beta1, [float]) - validator.check_type("beta2", beta2, [float]) - validator.check_type("eps", eps, [float]) - validator.check_type("weight_dacay", weight_decay, [float]) - validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_value_type("decay_steps", decay_steps, [int], prim_name) + validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name) + validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name) + validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) + validator.check_value_type("power", power, [float], prim_name) + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) class Lamb(Optimizer): @@ -182,7 +182,7 @@ class Lamb(Optimizer): super(Lamb, self).__init__(start_learning_rate, params) _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, - power, beta1, beta2, eps, weight_decay) + power, beta1, beta2, eps, weight_decay, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations self.global_step = Parameter(initializer(0, [1]), name="global_step") diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 00d3fd3b7bd571b9822a2402782980c062b65587..5738044532c73564a24cb87912e6f01d7a444177 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore.common.tensor import Tensor from mindspore import log as logger @@ -63,7 +63,7 @@ class Optimizer(Cell): self.gather = None self.assignadd = None self.global_step = None - validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) else: self.dynamic_lr = True self.gather = P.GatherV2() diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index a68dc6f7c44ce6df9e220365d41acf2617abad5e..97d7538a26d91b4ea67c7b5bb6e670d2c8e3b459 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -14,7 +14,7 @@ # ============================================================================ """rmsprop""" from mindspore.ops import functional as F, composite as C, operations as P -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .optimizer import Optimizer rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") @@ -144,8 +144,8 @@ class RMSProp(Optimizer): self.decay = decay self.epsilon = epsilon - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("centered", centered, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("centered", centered, [bool], self.cls_name) self.centered = centered if centered: self.opt = P.ApplyCenteredRMSProp(use_locking) diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 983be4bf80b5923509001783e6ce6f7d544576dd..db0775e023e22510cc4ec49888d4a1d0882508d6 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -15,7 +15,7 @@ """sgd""" from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .optimizer import Optimizer sgd_opt = C.MultitypeFuncGraph("sgd_opt") @@ -100,7 +100,7 @@ class SGD(Optimizer): raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) self.dampening = dampening - validator.check_type("nesterov", nesterov, [bool]) + validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) self.nesterov = nesterov self.opt = P.SGD(dampening, weight_decay, nesterov) diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index e4b0bfdbfedc78ee4cb83e33c0a2046dc1eb445e..752b367023a2afa6e3e2d370e29a3f55d62852bf 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -19,7 +19,7 @@ import os import json import inspect from mindspore._c_expression import Oplib -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator # path of built-in op info register. BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" @@ -43,7 +43,7 @@ def op_info_register(op_info): op_info_real = json.dumps(op_info) else: op_info_real = op_info - validator.check_type("op_info", op_info_real, [str]) + validator.check_value_type("op_info", op_info_real, [str], None) op_lib = Oplib() file_path = os.path.realpath(inspect.getfile(func)) # keep the path custom ops implementation. diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index c4c115ef27f4f13f4dab6a846e99c7e76ed7256b..66e08874b288d063c03e64b582e0b778038bea6f 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -16,7 +16,7 @@ from easydict import EasyDict as edict from .. import nn -from .._checkparam import ParamValidator as validator +from .._checkparam import Validator as validator from .._checkparam import Rel from ..common import dtype as mstype from ..nn.wrap.cell_wrapper import _VirtualDatasetCell @@ -73,14 +73,14 @@ def _check_kwargs(key_words): raise ValueError(f"Unsupported arg '{arg}'") if 'cast_model_type' in key_words: - validator.check('cast_model_type', key_words['cast_model_type'], - [mstype.float16, mstype.float32], Rel.IN) + validator.check_type_name('cast_model_type', key_words['cast_model_type'], + [mstype.float16, mstype.float32], None) if 'keep_batchnorm_fp32' in key_words: - validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool) + validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None) if 'loss_scale_manager' in key_words: loss_scale_manager = key_words['loss_scale_manager'] if loss_scale_manager: - validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager) + validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None) def _add_loss_network(network, loss_fn, cast_model_type): @@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): label = _mp_cast_helper(mstype.float32, label) return self._loss_fn(F.cast(out, mstype.float32), label) - validator.check_isinstance('loss_fn', loss_fn, nn.Cell) + validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) if cast_model_type == mstype.float16: network = WithLossCell(network, loss_fn) else: @@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else scale the loss by LossScaleManager. If set, overwrite the level setting. """ - validator.check_isinstance('network', network, nn.Cell) - validator.check_isinstance('optimizer', optimizer, nn.Optimizer) - validator.check('level', level, "", ['O0', 'O2'], Rel.IN) + validator.check_value_type('network', network, nn.Cell, None) + validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) + validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) _check_kwargs(kwargs) config = dict(_config_level[level], **kwargs) config = edict(config) diff --git a/mindspore/train/loss_scale_manager.py b/mindspore/train/loss_scale_manager.py index 5650c58f6207df956c2a2ee2a1665079b14d7e6c..c8c28a72cb91c96318323e9ab11c437fbba21b63 100644 --- a/mindspore/train/loss_scale_manager.py +++ b/mindspore/train/loss_scale_manager.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """Loss scale manager abstract class.""" -from .._checkparam import ParamValidator as validator +from .._checkparam import Validator as validator from .._checkparam import Rel from .. import nn @@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager): if init_loss_scale < 1.0: raise ValueError("Loss scale value should be > 1") self.loss_scale = init_loss_scale - validator.check_integer("scale_window", scale_window, 0, Rel.GT) + validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__) self.scale_window = scale_window if scale_factor <= 0: raise ValueError("Scale factor should be > 1") diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index cb959956d6b270f31d59f3d0fe62865a391c8aa5..96f9d5afdeddc314c1be8e7959dec1a760a0dd51 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -32,7 +32,7 @@ power = 0.5 class TestInputs: def test_milestone1(self): milestone1 = 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone1, learning_rates) def test_milestone2(self): @@ -46,12 +46,12 @@ class TestInputs: def test_learning_rates1(self): lr = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone, lr) def test_learning_rates2(self): lr = [1, 2, 1] - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone, lr) def test_learning_rate_type(self): @@ -158,7 +158,7 @@ class TestInputs: def test_is_stair(self): is_stair = 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) def test_min_lr_type(self): @@ -183,12 +183,12 @@ class TestInputs: def test_power(self): power1 = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1) def test_update_decay_epoch(self): update_decay_epoch = 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, update_decay_epoch) diff --git a/tests/ut/python/nn/test_psnr.py b/tests/ut/python/nn/test_psnr.py index 5a908b308dd105fdbb18c08e06492b86910cefb0..32e7b570aa275d38cc33a6425bb1ee55c03b06a9 100644 --- a/tests/ut/python/nn/test_psnr.py +++ b/tests/ut/python/nn/test_psnr.py @@ -52,7 +52,7 @@ def test_psnr_max_val_negative(): def test_psnr_max_val_bool(): max_val = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = PSNRNet(max_val) def test_psnr_max_val_zero(): diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index a698b59f69cacc81c6fb958fe48e137c4e59d49f..cf946a16172dea06edbd64cda8075dbe68e00a2b 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -51,7 +51,7 @@ def test_ssim_max_val_negative(): def test_ssim_max_val_bool(): max_val = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = SSIMNet(max_val) def test_ssim_max_val_zero(): @@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value(): with pytest.raises(ValueError): net = SSIMNet(k2=0.0) with pytest.raises(ValueError): - net = SSIMNet(k2=-1.0) \ No newline at end of file + net = SSIMNet(k2=-1.0) diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index bb2bb3ea9f7d00929be5d8a9b444a0f94eb507fa..09a4248c19caef49d870d2a57229f468be7787fc 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -594,14 +594,14 @@ test_cases_for_verify_exception = [ ('MaxPool2d_ValueError_2', { 'block': ( lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"), - {'exception': ValueError}, + {'exception': TypeError}, ), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], }), ('MaxPool2d_ValueError_3', { 'block': ( lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"), - {'exception': ValueError}, + {'exception': TypeError}, ), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], }), diff --git a/tests/ut/python/pynative_mode/nn/test_pooling.py b/tests/ut/python/pynative_mode/nn/test_pooling.py index bb1822f8a8a4a0f286509447fa74532a3ad9ddc7..f8df3ada3f24ec4e51395419f46bc6135be8106f 100644 --- a/tests/ut/python/pynative_mode/nn/test_pooling.py +++ b/tests/ut/python/pynative_mode/nn/test_pooling.py @@ -38,7 +38,7 @@ def test_avgpool2d_error_input(): """ test_avgpool2d_error_input """ kernel_size = 5 stride = 2.3 - with pytest.raises(ValueError): + with pytest.raises(TypeError): nn.AvgPool2d(kernel_size, stride)