提交 c3ec9712 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!403 add cell class name to error message

Merge pull request !403 from fary86/add_cell_name_to_error_message_for_nn_layer
......@@ -17,7 +17,7 @@ import re
from enum import Enum
from functools import reduce
from itertools import repeat
from collections import Iterable
from collections.abc import Iterable
import numpy as np
from mindspore import log as logger
......@@ -98,7 +98,7 @@ class Validator:
"""validator for checking input parameters"""
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
"""
Method for judging relation between two int values or list/tuple made up of ints.
......@@ -108,8 +108,8 @@ class Validator:
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
msg_prefix = f'For {prim_name} the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel, prim_name):
......@@ -118,8 +118,17 @@ class Validator:
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
f' but got {arg_value}.')
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number(arg_name, arg_value, value, rel, prim_name):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
......@@ -133,9 +142,46 @@ class Validator:
f' but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
"""Method for checking whether a numeric value is in some range."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_string(arg_name, arg_value, valid_values, prim_name):
"""Checks whether a string is in some value list"""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_pad_value_by_mode(pad_mode, padding, prim_name):
"""Validates value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_float_positive(arg_name, arg_value, prim_name):
"""Float type judgment."""
msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
@staticmethod
def check_subclass(arg_name, type_, template_type, prim_name):
"""Check whether some type is sublcass of another type"""
"""Checks whether some type is sublcass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
......@@ -143,16 +189,44 @@ class Validator:
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_const_input(arg_name, arg_value, prim_name):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_scalar_type_same(args, valid_values, prim_name):
"""check whether the types of inputs are the same."""
def _check_tensor_type(arg):
arg_key, arg_val = arg
elem_type = arg_val
if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.')
return (arg_key, elem_type)
def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
if arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1
elem_types = map(_check_tensor_type, args.items())
reduce(_check_types_same, elem_types)
@staticmethod
def check_tensor_type_same(args, valid_values, prim_name):
"""check whether the element types of input tensors are the same."""
"""Checks whether the element types of input tensors are the same."""
def _check_tensor_type(arg):
arg_key, arg_val = arg
Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
elem_type = arg_val.element_type()
if not elem_type in valid_values:
raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
f' but `{arg_key}` is {elem_type}.')
f' but element type of `{arg_key}` is {elem_type}.')
return (arg_key, elem_type)
def _check_types_same(arg1, arg2):
......@@ -168,8 +242,13 @@ class Validator:
@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
"""check whether the types of inputs are the same. if the input args are tensors, check their element types"""
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
"""
Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
"""
def _check_argument_type(arg):
arg_key, arg_val = arg
if isinstance(arg_val, type(mstype.tensor)):
......@@ -188,6 +267,9 @@ class Validator:
arg2_type = arg2_type.element_type()
elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
pass
elif allow_mix:
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
else:
excp_flag = True
......@@ -199,13 +281,14 @@ class Validator:
@staticmethod
def check_value_type(arg_name, arg_value, valid_types, prim_name):
"""Check whether a values is instance of some types."""
"""Checks whether a value is instance of some types."""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
f'{"one of " if num_types > 1 else ""}'
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
# Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
......@@ -216,6 +299,23 @@ class Validator:
return arg_value
raise_error_msg()
@staticmethod
def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1:
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
......
......@@ -103,6 +103,10 @@ class Cell:
def parameter_layout_dict(self):
return self._parameter_layout_dict
@property
def cls_name(self):
return self.__class__.__name__
@parameter_layout_dict.setter
def parameter_layout_dict(self, value):
if not isinstance(value, dict):
......
......@@ -15,7 +15,7 @@
"""dynamic learning rate"""
import math
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
......@@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates):
>>> lr = piecewise_constant_lr(milestone, learning_rates)
[0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
"""
validator.check_type('milestone', milestone, (tuple, list))
validator.check_type('learning_rates', learning_rates, (tuple, list))
validator.check_value_type('milestone', milestone, (tuple, list), None)
validator.check_value_type('learning_rates', learning_rates, (tuple, list), None)
if len(milestone) != len(learning_rates):
raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')
lr = []
last_item = 0
for i, item in enumerate(milestone):
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT)
validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float])
validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
if item < last_item:
raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
lr += [learning_rates[i]] * (item - last_item)
......@@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates):
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
validator.check_integer('total_step', total_step, 0, Rel.GT)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
validator.check_float_positive('learning_rate', learning_rate)
validator.check_float_positive('decay_rate', decay_rate)
validator.check_type('is_stair', is_stair, [bool])
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_positive('decay_rate', decay_rate, None)
validator.check_value_type('is_stair', is_stair, [bool], None)
def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
......@@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
>>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
[0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
"""
validator.check_float_positive('min_lr', min_lr)
validator.check_float_positive('max_lr', max_lr)
validator.check_integer('total_step', total_step, 0, Rel.GT)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
validator.check_float_positive('min_lr', min_lr, None)
validator.check_float_positive('max_lr', max_lr, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
delta = 0.5 * (max_lr - min_lr)
lr = []
......@@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
>>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
[0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
"""
validator.check_float_positive('learning_rate', learning_rate)
validator.check_float_positive('end_learning_rate', end_learning_rate)
validator.check_integer('total_step', total_step, 0, Rel.GT)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
validator.check_type('power', power, [float])
validator.check_type('update_decay_epoch', update_decay_epoch, [bool])
validator.check_float_positive('learning_rate', learning_rate, None)
validator.check_float_positive('end_learning_rate', end_learning_rate, None)
validator.check_integer('total_step', total_step, 0, Rel.GT, None)
validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
validator.check_value_type('power', power, [float], None)
validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
function = lambda x, y: (x, min(x, y))
if update_decay_epoch:
......
......@@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from ..cell import Cell
from .activation import get_activation
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
class Dropout(Cell):
......@@ -73,7 +73,7 @@ class Dropout(Cell):
super(Dropout, self).__init__()
if keep_prob <= 0 or keep_prob > 1:
raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
validator.check_subclass("dtype", dtype, mstype.number_type)
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.keep_prob = Tensor(keep_prob)
self.seed0 = seed0
self.seed1 = seed1
......@@ -421,7 +421,7 @@ class Pad(Cell):
super(Pad, self).__init__()
self.mode = mode
self.paddings = paddings
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"])
validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name)
if not isinstance(paddings, tuple):
raise TypeError('Paddings must be tuple type.')
for item in paddings:
......
......@@ -19,7 +19,7 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from ..cell import Cell
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
class Embedding(Cell):
......@@ -59,7 +59,7 @@ class Embedding(Cell):
"""
def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
super(Embedding, self).__init__()
validator.check_subclass("dtype", dtype, mstype.number_type)
validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.use_one_hot = use_one_hot
......
......@@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..cell import Cell
......@@ -134,15 +134,15 @@ class SSIM(Cell):
"""
def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
super(SSIM, self).__init__()
validator.check_type('max_val', max_val, [int, float])
validator.check('max_val', max_val, '', 0.0, Rel.GT)
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE)
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma)
validator.check_type('k1', k1, [float])
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER)
validator.check_type('k2', k2, [float])
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER)
self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
validator.check_value_type('k1', k1, [float], self.cls_name)
self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
validator.check_value_type('k2', k2, [float], self.cls_name)
self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
def construct(self, img1, img2):
......@@ -231,8 +231,8 @@ class PSNR(Cell):
"""
def __init__(self, max_val=1.0):
super(PSNR, self).__init__()
validator.check_type('max_val', max_val, [int, float])
validator.check('max_val', max_val, '', 0.0, Rel.GT)
validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
self.max_val = max_val
def construct(self, img1, img2):
......
......@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
class LSTM(Cell):
......@@ -114,7 +114,7 @@ class LSTM(Cell):
self.hidden_size = hidden_size
self.num_layers = num_layers
self.has_bias = has_bias
self.batch_first = validator.check_type("batch_first", batch_first, [bool])
self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
self.dropout = float(dropout)
self.bidirectional = bidirectional
......
......@@ -14,8 +14,7 @@
# ============================================================================
"""pooling"""
from mindspore.ops import operations as P
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from ... import context
from ..cell import Cell
......@@ -24,35 +23,27 @@ class _PoolNd(Cell):
"""N-D AvgPool"""
def __init__(self, kernel_size, stride, pad_mode):
name = self.__class__.__name__
super(_PoolNd, self).__init__()
validator.check_type('kernel_size', kernel_size, [int, tuple])
validator.check_type('stride', stride, [int, tuple])
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
if isinstance(kernel_size, int):
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
else:
if (len(kernel_size) != 2 or
(not isinstance(kernel_size[0], int)) or
(not isinstance(kernel_size[1], int)) or
kernel_size[0] <= 0 or
kernel_size[1] <= 0):
raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or'
f'a tuple of two positive int numbers, but got {kernel_size}')
self.kernel_size = kernel_size
if isinstance(stride, int):
validator.check_integer("stride", stride, 1, Rel.GE)
else:
if (len(stride) != 2 or
(not isinstance(stride[0], int)) or
(not isinstance(stride[1], int)) or
stride[0] <= 0 or
stride[1] <= 0):
raise ValueError(f'The stride passed to cell {name} should be an positive int number or'
f'a tuple of two positive int numbers, but got {stride}')
self.stride = stride
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
def _check_int_or_tuple(arg_name, arg_value):
validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \
f'a tuple of two positive int numbers, but got {arg_value}'
if isinstance(arg_value, int):
if arg_value <= 0:
raise ValueError(error_msg)
elif len(arg_value) == 2:
for item in arg_value:
if isinstance(item, int) and item > 0:
continue
raise ValueError(error_msg)
else:
raise ValueError(error_msg)
return arg_value
self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size)
self.stride = _check_int_or_tuple('stride', stride)
def construct(self, *inputs):
pass
......
......@@ -15,7 +15,7 @@
"""Fbeta."""
import sys
import numpy as np
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .metric import Metric
......@@ -104,7 +104,7 @@ class Fbeta(Metric):
Returns:
Float, computed result.
"""
validator.check_type("average", average, [bool])
validator.check_value_type("average", average, [bool], self.__class__.__name__)
if self._class_num == 0:
raise RuntimeError('Input number of samples can not be 0.')
......
......@@ -17,7 +17,7 @@ import sys
import numpy as np
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .evaluation import EvaluationBase
......@@ -136,7 +136,7 @@ class Precision(EvaluationBase):
if self._class_num == 0:
raise RuntimeError('Input number of samples can not be 0.')
validator.check_type("average", average, [bool])
validator.check_value_type("average", average, [bool], self.__class__.__name__)
result = self._true_positives / (self._positives + self.eps)
if average:
......
......@@ -17,7 +17,7 @@ import sys
import numpy as np
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .evaluation import EvaluationBase
......@@ -136,7 +136,7 @@ class Recall(EvaluationBase):
if self._class_num == 0:
raise RuntimeError('Input number of samples can not be 0.')
validator.check_type("average", average, [bool])
validator.check_value_type("average", average, [bool], self.__class__.__name__)
result = self._true_positives / (self._actual_positives + self.eps)
if average:
......
......@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
......@@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
return next_v
def _check_param_value(beta1, beta2, eps, weight_decay):
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
validator.check_type("beta1", beta1, [float])
validator.check_type("beta2", beta2, [float])
validator.check_type("eps", eps, [float])
validator.check_type("weight_dacay", weight_decay, [float])
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT)
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
@adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
......@@ -168,11 +168,11 @@ class Adam(Optimizer):
use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
_check_param_value(beta1, beta2, eps, weight_decay)
validator.check_type("use_locking", use_locking, [bool])
validator.check_type("use_nesterov", use_nesterov, [bool])
validator.check_type("loss_scale", loss_scale, [float])
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)
self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32)
......@@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer):
"""
def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
super(AdamWeightDecay, self).__init__(learning_rate, params)
_check_param_value(beta1, beta2, eps, weight_decay)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
self.lr = Tensor(np.array([learning_rate]).astype(np.float32))
self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
......@@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer):
eps=1e-6,
weight_decay=0.0):
super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
_check_param_value(beta1, beta2, eps, weight_decay)
_check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations
self.global_step = Parameter(initializer(0, [1]), name="global_step")
......
......@@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer, apply_decay, grad_scale
......@@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
return success
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0):
validator.check_type("initial_accum", initial_accum, [float])
validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE)
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
prim_name=None):
validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
validator.check_type("learning_rate", learning_rate, [float])
validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT)
validator.check_value_type("learning_rate", learning_rate, [float], prim_name)
validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name)
validator.check_type("lr_power", lr_power, [float])
validator.check("lr_power", lr_power, "", 0.0, Rel.LE)
validator.check_value_type("lr_power", lr_power, [float], prim_name)
validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name)
validator.check_type("l1", l1, [float])
validator.check("l1", l1, "", 0.0, Rel.GE)
validator.check_value_type("l1", l1, [float], prim_name)
validator.check_number("l1", l1, 0.0, Rel.GE, prim_name)
validator.check_type("l2", l2, [float])
validator.check("l2", l2, "", 0.0, Rel.GE)
validator.check_value_type("l2", l2, [float], prim_name)
validator.check_number("l2", l2, 0.0, Rel.GE, prim_name)
validator.check_type("use_locking", use_locking, [bool])
validator.check_value_type("use_locking", use_locking, [bool], prim_name)
validator.check_type("loss_scale", loss_scale, [float])
validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE)
validator.check_value_type("loss_scale", loss_scale, [float], prim_name)
validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name)
validator.check_type("weight_decay", weight_decay, [float])
validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE)
validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)
class FTRL(Optimizer):
......@@ -94,7 +95,8 @@ class FTRL(Optimizer):
use_locking=False, loss_scale=1.0, weight_decay=0.0):
super(FTRL, self).__init__(learning_rate, params)
_check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay)
_check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay,
self.cls_name)
self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
self.linear = self.parameters.clone(prefix="linear", init='zeros')
self.l1 = l1
......
......@@ -21,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .. import layer
......@@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay):
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
validator.check_type("decay_steps", decay_steps, [int])
validator.check_type("warmup_steps", warmup_steps, [int])
validator.check_type("start_learning_rate", start_learning_rate, [float])
validator.check_type("end_learning_rate", end_learning_rate, [float])
validator.check_type("power", power, [float])
validator.check_type("beta1", beta1, [float])
validator.check_type("beta2", beta2, [float])
validator.check_type("eps", eps, [float])
validator.check_type("weight_dacay", weight_decay, [float])
validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT)
validator.check_value_type("decay_steps", decay_steps, [int], prim_name)
validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name)
validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
validator.check_value_type("power", power, [float], prim_name)
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name)
validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class Lamb(Optimizer):
......@@ -182,7 +182,7 @@ class Lamb(Optimizer):
super(Lamb, self).__init__(start_learning_rate, params)
_check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
power, beta1, beta2, eps, weight_decay)
power, beta1, beta2, eps, weight_decay, self.cls_name)
# turn them to scalar when me support scalar/tensor mix operations
self.global_step = Parameter(initializer(0, [1]), name="global_step")
......
......@@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger
......@@ -63,7 +63,7 @@ class Optimizer(Cell):
self.gather = None
self.assignadd = None
self.global_step = None
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
else:
self.dynamic_lr = True
self.gather = P.GatherV2()
......
......@@ -14,7 +14,7 @@
# ============================================================================
"""rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
......@@ -144,8 +144,8 @@ class RMSProp(Optimizer):
self.decay = decay
self.epsilon = epsilon
validator.check_type("use_locking", use_locking, [bool])
validator.check_type("centered", centered, [bool])
validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
validator.check_value_type("centered", centered, [bool], self.cls_name)
self.centered = centered
if centered:
self.opt = P.ApplyCenteredRMSProp(use_locking)
......
......@@ -15,7 +15,7 @@
"""sgd"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
sgd_opt = C.MultitypeFuncGraph("sgd_opt")
......@@ -100,7 +100,7 @@ class SGD(Optimizer):
raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
self.dampening = dampening
validator.check_type("nesterov", nesterov, [bool])
validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
self.nesterov = nesterov
self.opt = P.SGD(dampening, weight_decay, nesterov)
......
......@@ -19,7 +19,7 @@ import os
import json
import inspect
from mindspore._c_expression import Oplib
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
# path of built-in op info register.
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
......@@ -43,7 +43,7 @@ def op_info_register(op_info):
op_info_real = json.dumps(op_info)
else:
op_info_real = op_info
validator.check_type("op_info", op_info_real, [str])
validator.check_value_type("op_info", op_info_real, [str], None)
op_lib = Oplib()
file_path = os.path.realpath(inspect.getfile(func))
# keep the path custom ops implementation.
......
......@@ -16,7 +16,7 @@
from easydict import EasyDict as edict
from .. import nn
from .._checkparam import ParamValidator as validator
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
......@@ -73,14 +73,14 @@ def _check_kwargs(key_words):
raise ValueError(f"Unsupported arg '{arg}'")
if 'cast_model_type' in key_words:
validator.check('cast_model_type', key_words['cast_model_type'],
[mstype.float16, mstype.float32], Rel.IN)
validator.check_type_name('cast_model_type', key_words['cast_model_type'],
[mstype.float16, mstype.float32], None)
if 'keep_batchnorm_fp32' in key_words:
validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None)
if 'loss_scale_manager' in key_words:
loss_scale_manager = key_words['loss_scale_manager']
if loss_scale_manager:
validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)
validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None)
def _add_loss_network(network, loss_fn, cast_model_type):
......@@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
label = _mp_cast_helper(mstype.float32, label)
return self._loss_fn(F.cast(out, mstype.float32), label)
validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
if cast_model_type == mstype.float16:
network = WithLossCell(network, loss_fn)
else:
......@@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If set, overwrite the level setting.
"""
validator.check_isinstance('network', network, nn.Cell)
validator.check_isinstance('optimizer', optimizer, nn.Optimizer)
validator.check('level', level, "", ['O0', 'O2'], Rel.IN)
validator.check_value_type('network', network, nn.Cell, None)
validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
_check_kwargs(kwargs)
config = dict(_config_level[level], **kwargs)
config = edict(config)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Loss scale manager abstract class."""
from .._checkparam import ParamValidator as validator
from .._checkparam import Validator as validator
from .._checkparam import Rel
from .. import nn
......@@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
if init_loss_scale < 1.0:
raise ValueError("Loss scale value should be > 1")
self.loss_scale = init_loss_scale
validator.check_integer("scale_window", scale_window, 0, Rel.GT)
validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__)
self.scale_window = scale_window
if scale_factor <= 0:
raise ValueError("Scale factor should be > 1")
......
......@@ -32,7 +32,7 @@ power = 0.5
class TestInputs:
def test_milestone1(self):
milestone1 = 1
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone1, learning_rates)
def test_milestone2(self):
......@@ -46,12 +46,12 @@ class TestInputs:
def test_learning_rates1(self):
lr = True
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone, lr)
def test_learning_rates2(self):
lr = [1, 2, 1]
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.piecewise_constant_lr(milestone, lr)
def test_learning_rate_type(self):
......@@ -158,7 +158,7 @@ class TestInputs:
def test_is_stair(self):
is_stair = 1
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
def test_min_lr_type(self):
......@@ -183,12 +183,12 @@ class TestInputs:
def test_power(self):
power1 = True
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)
def test_update_decay_epoch(self):
update_decay_epoch = 1
with pytest.raises(ValueError):
with pytest.raises(TypeError):
dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
power, update_decay_epoch)
......
......@@ -52,7 +52,7 @@ def test_psnr_max_val_negative():
def test_psnr_max_val_bool():
max_val = True
with pytest.raises(ValueError):
with pytest.raises(TypeError):
net = PSNRNet(max_val)
def test_psnr_max_val_zero():
......
......@@ -51,7 +51,7 @@ def test_ssim_max_val_negative():
def test_ssim_max_val_bool():
max_val = True
with pytest.raises(ValueError):
with pytest.raises(TypeError):
net = SSIMNet(max_val)
def test_ssim_max_val_zero():
......@@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value():
with pytest.raises(ValueError):
net = SSIMNet(k2=0.0)
with pytest.raises(ValueError):
net = SSIMNet(k2=-1.0)
\ No newline at end of file
net = SSIMNet(k2=-1.0)
......@@ -594,14 +594,14 @@ test_cases_for_verify_exception = [
('MaxPool2d_ValueError_2', {
'block': (
lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
{'exception': ValueError},
{'exception': TypeError},
),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
}),
('MaxPool2d_ValueError_3', {
'block': (
lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
{'exception': ValueError},
{'exception': TypeError},
),
'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
}),
......
......@@ -38,7 +38,7 @@ def test_avgpool2d_error_input():
""" test_avgpool2d_error_input """
kernel_size = 5
stride = 2.3
with pytest.raises(ValueError):
with pytest.raises(TypeError):
nn.AvgPool2d(kernel_size, stride)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册