提交 20782294 编写于 作者: F fary86

Add prim name to error message for array_ops

上级 789edcb2
...@@ -210,7 +210,7 @@ class Validator: ...@@ -210,7 +210,7 @@ class Validator:
type_names = [] type_names = []
for t in valid_values: for t in valid_values:
type_names.append(str(t)) type_names.append(str(t))
types_info = '[' + ", ".join(type_names) + ']' types_info = '[' + ', '.join(type_names) + ']'
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
f' but got {elem_type}.') f' but got {elem_type}.')
return (arg_key, elem_type) return (arg_key, elem_type)
...@@ -320,224 +320,6 @@ class Validator: ...@@ -320,224 +320,6 @@ class Validator:
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
"""Judging valid value."""
if not cond:
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
@staticmethod
def check_integer(arg_name, arg_value, value, rel):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_shape_length(arg_name, arg_value, value, rel):
"""Shape length judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""Is it necessary to consider error when comparing float values."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_args_tensor(args):
"""Check whether args are all tensor."""
if not isinstance(args, dict):
raise TypeError("The args should be a dict.")
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is subclass of int, so `check_type('x', True, [int])` will check fail, and
# `check_type('x', True, [bool, int])` will check pass
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
@staticmethod
def check_typename(arg_name, arg_type, valid_types):
"""Does it contain the _name_ attribute."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
@staticmethod
def check_string(arg_name, arg_value, valid_values):
"""String type judgment."""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_type_same(args, valid_values):
"""Determine whether the types are the same."""
name = list(args.keys())[0]
value = list(args.values())[0]
if isinstance(value, type(mstype.tensor)):
value = value.element_type()
for arg_name, arg_value in args.items():
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if arg_value not in valid_values:
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
f' but `{arg_name}` is {arg_value}.')
if arg_value != value:
raise TypeError(f'`{arg_name}` should be same as `{name}`,'
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
@staticmethod
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
"""Determine whether the types of two variables are the same."""
if arg1_type != arg2_type:
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be same.')
@staticmethod
def check_value_on_integer(arg_name, arg_value, value, rel):
"""Judging integer type."""
rel_fn = Rel.get_fns(rel)
type_match = isinstance(arg_value, int)
if type_match and (not rel_fn(arg_value, value)):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
"""Judging the equality of parameters."""
if param1_value != param2_value:
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
f" but got `{param1_name}` = {param1_value},"
f" `{param2_name}` = {param2_value}.")
@staticmethod
def check_const_input(arg_name, arg_value):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_float_positive(arg_name, arg_value):
"""Float type judgment."""
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"`{arg_name}` must be float!")
@staticmethod
def check_pad_value_by_mode(op_name, pad_mode, padding):
"""Validate value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
@staticmethod
def check_empty_shape_input(arg_name, arg_value):
"""Check zeros value."""
if 0 in arg_value:
raise ValueError(f"Input `{arg_name}` cannot be empty.")
@staticmethod
def check_scalar_shape_input(arg_name, arg_value):
"""Check scalar shape input."""
if arg_value != []:
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
def check_int(input_param): def check_int(input_param):
"""Int type judgment.""" """Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool): if isinstance(input_param, int) and not isinstance(input_param, bool):
...@@ -653,30 +435,6 @@ def check_output_data(data): ...@@ -653,30 +435,6 @@ def check_output_data(data):
raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.') raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
def check_axis_type_int(axis):
"""Check axis type."""
if not isinstance(axis, int):
raise TypeError('Wrong type for axis, should be int.')
def check_axis_range(axis, rank):
"""Check axis range."""
if not -rank <= axis < rank:
raise ValueError('The axis should be in range [{}, {}),'' but got {}.'.format(-rank, rank, axis))
def check_attr_int(attr_name, attr):
"""Check int type."""
if not isinstance(attr, int):
raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
def check_t_in_range(t):
"""Check input range."""
if t not in (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64):
raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
once = _expand_tuple(1) once = _expand_tuple(1)
twice = _expand_tuple(2) twice = _expand_tuple(2)
triple = _expand_tuple(3) triple = _expand_tuple(3)
......
...@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { ...@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
UpdateAdjoint(node_adjoint); UpdateAdjoint(node_adjoint);
anfnode_to_adjoin_[morph] = node_adjoint; anfnode_to_adjoin_[morph] = node_adjoint;
if (cnode_morph->stop_gradient()) { if (cnode_morph->stop_gradient()) {
MS_LOG(WARNING) << "MapMorphism node " << morph->ToString() << " is stopped."; MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
return node_adjoint; return node_adjoint;
} }
......
...@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator ...@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator
from ... import context from ... import context
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Rel from ..._checkparam import Rel
from ..._checkparam import ParamValidator
class _PoolNd(Cell): class _PoolNd(Cell):
...@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd): ...@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd):
stride=1, stride=1,
pad_mode="valid"): pad_mode="valid"):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
ParamValidator.check_type('kernel_size', kernel_size, [int,]) validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
ParamValidator.check_type('stride', stride, [int,]) validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
ParamValidator.check_integer("stride", stride, 1, Rel.GE) validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
self.kernel_size = (1, kernel_size) self.kernel_size = (1, kernel_size)
self.stride = (1, stride) self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size, self.avg_pool = P.AvgPool(ksize=self.kernel_size,
......
...@@ -24,7 +24,7 @@ import itertools ...@@ -24,7 +24,7 @@ import itertools
import numbers import numbers
import numpy as np import numpy as np
from ..._checkparam import ParamValidator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
...@@ -32,12 +32,12 @@ from ..operations.math_ops import _infer_shape_reduce ...@@ -32,12 +32,12 @@ from ..operations.math_ops import _infer_shape_reduce
from .._utils import _get_concat_offset from .._utils import _get_concat_offset
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
def _check_infer_attr_reduce(axis, keep_dims): def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_type('keep_dims', keep_dims, [bool]) validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
validator.check_type('axis', axis, [int, tuple]) validator.check_value_type('axis', axis, [int, tuple], prim_name)
if isinstance(axis, tuple): if isinstance(axis, tuple):
for index, value in enumerate(axis): for index, value in enumerate(axis):
validator.check_type('axis[%d]' % index, value, [int]) validator.check_value_type('axis[%d]' % index, value, [int], prim_name)
class ExpandDims(PrimitiveWithInfer): class ExpandDims(PrimitiveWithInfer):
...@@ -74,13 +74,11 @@ class ExpandDims(PrimitiveWithInfer): ...@@ -74,13 +74,11 @@ class ExpandDims(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
def __infer__(self, x, axis): def __infer__(self, x, axis):
validator.check_subclass("input_x", x['dtype'], mstype.tensor) validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
axis_v = axis['value'] axis_v = axis['value']
rank = len(x_shape) rank = len(x_shape)
validator.check_const_input('axis', axis_v) validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name)
validator.check_type("axis", axis_v, [int])
validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH)
if axis_v < 0: if axis_v < 0:
axis_v = rank + 1 + axis_v axis_v = rank + 1 + axis_v
x_shape.insert(axis_v, 1) x_shape.insert(axis_v, 1)
...@@ -110,7 +108,7 @@ class DType(PrimitiveWithInfer): ...@@ -110,7 +108,7 @@ class DType(PrimitiveWithInfer):
"""init DType""" """init DType"""
def __infer__(self, x): def __infer__(self, x):
validator.check_subclass("input_x", x['dtype'], mstype.tensor) validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
out = {'shape': (), out = {'shape': (),
'dtype': mstype.type_type, 'dtype': mstype.type_type,
'value': x['dtype'].element_type()} 'value': x['dtype'].element_type()}
...@@ -144,19 +142,17 @@ class SameTypeShape(PrimitiveWithInfer): ...@@ -144,19 +142,17 @@ class SameTypeShape(PrimitiveWithInfer):
def __call__(self, x, y): def __call__(self, x, y):
"""run in PyNative mode""" """run in PyNative mode"""
if x.dtype() != y.dtype(): validator.check_subclass('x', x.dtype(), mstype.tensor, self.name)
raise TypeError(f"The {x} and {y} should be same dtype.") validator.check_subclass('y', y.dtype(), mstype.tensor, self.name)
if x.shape() != y.shape(): validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError)
raise TypeError(f"The {x} and {y} should have same shape.") validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name)
return x return x
def __infer__(self, x, y): def __infer__(self, x, y):
if x['dtype'] != y['dtype']: validator.check_subclass('x', x['dtype'], mstype.tensor, self.name)
raise TypeError(f"The {x} and {y} should be same dtype," validator.check_subclass('y', y['dtype'], mstype.tensor, self.name)
f" but got {x['dtype']} {y['dtype']}.") validator.check('x dtype', x['dtype'], 'y dtype', y['dtype'], Rel.EQ, self.name, TypeError)
if x['shape'] != y['shape']: validator.check('x shape', x['shape'], 'y shape', y['shape'], Rel.EQ, self.name)
raise ValueError(f"The {x} and {y} should be same shape,"
f" but got {x['shape']} {y['shape']}.")
return x return x
...@@ -191,8 +187,8 @@ class Cast(PrimitiveWithInfer): ...@@ -191,8 +187,8 @@ class Cast(PrimitiveWithInfer):
src_type = x['dtype'] src_type = x['dtype']
dst_type = t['value'] dst_type = t['value']
validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number]) validator.check_subclass("input_x", src_type, [mstype.tensor, mstype.number], self.name)
validator.check_subclass("type", dst_type, mstype.number, with_type_of=False) validator.check_subclass("type", dst_type, mstype.number, self.name)
if isinstance(src_type, type(mstype.tensor)): if isinstance(src_type, type(mstype.tensor)):
src_type = x['dtype'].element_type() src_type = x['dtype'].element_type()
...@@ -238,8 +234,8 @@ class IsSubClass(PrimitiveWithInfer): ...@@ -238,8 +234,8 @@ class IsSubClass(PrimitiveWithInfer):
sub_type_t = sub_type['value'] sub_type_t = sub_type['value']
type_v = type_['value'] type_v = type_['value']
validator.check_type("sub_type", sub_type_t, [mstype.Type]) validator.check_value_type("sub_type", sub_type_t, [mstype.Type], self.name)
validator.check_type("type_", type_v, [mstype.Type]) validator.check_value_type("type_", type_v, [mstype.Type], self.name)
value = mstype.issubclass_(sub_type_t, type_v) value = mstype.issubclass_(sub_type_t, type_v)
...@@ -273,8 +269,8 @@ class IsInstance(PrimitiveWithInfer): ...@@ -273,8 +269,8 @@ class IsInstance(PrimitiveWithInfer):
sub_type_t = inst['dtype'] sub_type_t = inst['dtype']
type_v = type_['value'] type_v = type_['value']
validator.check_const_input("inst", inst['value']) validator.check_const_input("inst", inst['value'], self.name)
validator.check_type("type_", type_v, [mstype.Type]) validator.check_value_type("type_", type_v, [mstype.Type], self.name)
value = mstype.issubclass_(sub_type_t, type_v) value = mstype.issubclass_(sub_type_t, type_v)
...@@ -316,14 +312,13 @@ class Reshape(PrimitiveWithInfer): ...@@ -316,14 +312,13 @@ class Reshape(PrimitiveWithInfer):
def __infer__(self, x, shape): def __infer__(self, x, shape):
shape_v = shape['value'] shape_v = shape['value']
x_shp = x['shape'] x_shp = x['shape']
validator.check_subclass("x", x['dtype'], mstype.tensor) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
validator.check_const_input("shape", shape_v) validator.check_value_type("shape", shape_v, [tuple], self.name)
validator.check_type("shape", shape_v, [tuple])
shape_v = list(shape_v) shape_v = list(shape_v)
neg_index = -1 neg_index = -1
dim_prod = 1 dim_prod = 1
for i, shp_i in enumerate(shape_v): for i, shp_i in enumerate(shape_v):
validator.check_type("shape[%d]" % i, shp_i, [int]) validator.check_value_type("shape[%d]" % i, shp_i, [int], self.name)
if shp_i == -1: if shp_i == -1:
if neg_index != -1: if neg_index != -1:
raise ValueError(f'The shape can only has one -1 at most, but {shape_v}.') raise ValueError(f'The shape can only has one -1 at most, but {shape_v}.')
...@@ -332,7 +327,7 @@ class Reshape(PrimitiveWithInfer): ...@@ -332,7 +327,7 @@ class Reshape(PrimitiveWithInfer):
dim_prod *= shp_i dim_prod *= shp_i
arr_prod = np.prod(x_shp) arr_prod = np.prod(x_shp)
if dim_prod <= 0 or arr_prod % dim_prod != 0: if dim_prod <= 0 or arr_prod % dim_prod != 0:
raise ValueError(f'The product of shape should > 0 and' raise ValueError(f'For \'{self.name}\' the product of shape should > 0 and'
f' can be divided by prod of input {arr_prod},' f' can be divided by prod of input {arr_prod},'
f' but shape {shape}, product of shape {dim_prod}.') f' but shape {shape}, product of shape {dim_prod}.')
...@@ -340,7 +335,7 @@ class Reshape(PrimitiveWithInfer): ...@@ -340,7 +335,7 @@ class Reshape(PrimitiveWithInfer):
shape_v[neg_index] = int(arr_prod / dim_prod) shape_v[neg_index] = int(arr_prod / dim_prod)
dim_prod *= shape_v[neg_index] dim_prod *= shape_v[neg_index]
if dim_prod != arr_prod: if dim_prod != arr_prod:
raise ValueError(f'The shape arg for reshape must match array''s size' raise ValueError(f'For \'{self.name}\' The shape arg for reshape must match array''s size'
f' input shape {arr_prod}, shape {dim_prod}.') f' input shape {arr_prod}, shape {dim_prod}.')
value = None value = None
...@@ -406,10 +401,10 @@ class Squeeze(PrimitiveWithInfer): ...@@ -406,10 +401,10 @@ class Squeeze(PrimitiveWithInfer):
def __init__(self, axis=()): def __init__(self, axis=()):
"""init Squeeze""" """init Squeeze"""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_type('axis', axis, [int, tuple]) validator.check_value_type('axis', axis, [int, tuple], self.name)
if isinstance(axis, tuple): if isinstance(axis, tuple):
for item in axis: for idx, item in enumerate(axis):
validator.check_type("item", item, [int]) validator.check_value_type("axis[%d]" % idx, item, [int], self.name)
else: else:
self.axis = (axis,) self.axis = (axis,)
self.add_prim_attr("axis", (axis,)) self.add_prim_attr("axis", (axis,))
...@@ -422,14 +417,14 @@ class Squeeze(PrimitiveWithInfer): ...@@ -422,14 +417,14 @@ class Squeeze(PrimitiveWithInfer):
ret = [d for d in x_shape if d != 1] ret = [d for d in x_shape if d != 1]
else: else:
for a in axis: for a in axis:
validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH) validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name)
if x_shape[a] != 1: if x_shape[a] != 1:
raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.') raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.')
ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)] ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)]
return ret return ret
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor) validator.check_subclass("x", x_dtype, mstype.tensor, self.name)
return x_dtype return x_dtype
...@@ -467,14 +462,13 @@ class Transpose(PrimitiveWithInfer): ...@@ -467,14 +462,13 @@ class Transpose(PrimitiveWithInfer):
if len(x_shape) != len(p_value): if len(x_shape) != len(p_value):
raise ValueError('The dimension of x and perm must be equal.') raise ValueError('The dimension of x and perm must be equal.')
validator.check_const_input("perm", p_value) validator.check_value_type("p_value", p_value, [tuple], self.name)
validator.check_type("p_value", p_value, [tuple]) validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("x_type", x_type, mstype.tensor)
tmp = list(p_value) tmp = list(p_value)
for i, dim in enumerate(p_value): for i, dim in enumerate(p_value):
validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE) validator.check_integer("perm[%d]" % i, dim, 0, Rel.GE, self.name)
validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT) validator.check_integer("perm[%d]" % i, dim, len(p_value), Rel.LT, self.name)
tmp.remove(dim) tmp.remove(dim)
if dim in tmp: if dim in tmp:
raise ValueError('The value of perm is wrong.') raise ValueError('The value of perm is wrong.')
...@@ -517,15 +511,13 @@ class GatherV2(PrimitiveWithInfer): ...@@ -517,15 +511,13 @@ class GatherV2(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __infer__(self, params, indices, axis): def __infer__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor) validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_subclass("indices", indices['dtype'], mstype.tensor) validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], mstype.int_) validator.check_subclass("axis", axis['dtype'], mstype.int_, self.name)
validator.check_typename("element of indices", indices['dtype'], mstype.int_type)
validator.check_const_input("axis", axis['value'])
axis_v = axis['value'] axis_v = axis['value']
params_shp = params['shape'] params_shp = params['shape']
rank = len(params_shp) rank = len(params_shp)
validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT) validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name)
if axis_v < 0: if axis_v < 0:
axis_v += rank axis_v += rank
out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:]
...@@ -564,19 +556,20 @@ class Split(PrimitiveWithInfer): ...@@ -564,19 +556,20 @@ class Split(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, axis=0, output_num=1): def __init__(self, axis=0, output_num=1):
"""init Split""" """init Split"""
validator.check_type("axis", axis, [int]) validator.check_value_type("axis", axis, [int], self.name)
validator.check_type("output_num", output_num, [int]) validator.check_value_type("output_num", output_num, [int], self.name)
self.axis = axis self.axis = axis
self.output_num = output_num self.output_num = output_num
def __infer__(self, x): def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
dim = len(x_shape) dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
validator.check_integer("output_num", self.output_num, 0, Rel.GT) validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name)
output_valid_check = x_shape[self.axis] % self.output_num output_valid_check = x_shape[self.axis] % self.output_num
validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ) validator.check_integer("the dimension which to split divides output_num", output_valid_check, 0, Rel.EQ,
self.name)
x_shape[self.axis] = int(x_shape[self.axis] / self.output_num) x_shape[self.axis] = int(x_shape[self.axis] / self.output_num)
out_shapes = [] out_shapes = []
out_dtypes = [] out_dtypes = []
...@@ -615,7 +608,7 @@ class Rank(PrimitiveWithInfer): ...@@ -615,7 +608,7 @@ class Rank(PrimitiveWithInfer):
"""init Rank""" """init Rank"""
def __infer__(self, x): def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
out = {'shape': None, out = {'shape': None,
'dtype': None, 'dtype': None,
'value': len(x['shape'])} 'value': len(x['shape'])}
...@@ -647,15 +640,14 @@ class TruncatedNormal(PrimitiveWithInfer): ...@@ -647,15 +640,14 @@ class TruncatedNormal(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, seed=0, dtype=mstype.float32): def __init__(self, seed=0, dtype=mstype.float32):
"""init TruncatedNormal""" """init TruncatedNormal"""
validator.check_type('seed', seed, [int]) validator.check_value_type('seed', seed, [int], self.name)
validator.check_typename('dtype', dtype, mstype.number_type) validator.check_type_same({'dtype': dtype}, mstype.number_type, self.name)
def __infer__(self, shape): def __infer__(self, shape):
shape_value = shape['value'] shape_value = shape['value']
validator.check_const_input("shape", shape_value) validator.check_value_type("shape", shape_value, [tuple], self.name)
validator.check_type("shape", shape_value, [tuple])
for i, value in enumerate(shape_value): for i, value in enumerate(shape_value):
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT) validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name)
out = {'shape': shape_value, out = {'shape': shape_value,
'dtype': mstype.tensor_type(self.dtype), 'dtype': mstype.tensor_type(self.dtype),
'value': None} 'value': None}
...@@ -687,7 +679,7 @@ class Size(PrimitiveWithInfer): ...@@ -687,7 +679,7 @@ class Size(PrimitiveWithInfer):
def __infer__(self, x): def __infer__(self, x):
size = 1 size = 1
validator.check_subclass("x", x['dtype'], mstype.tensor) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
shp = x['shape'] shp = x['shape']
if not shp: if not shp:
size = 0 size = 0
...@@ -723,25 +715,20 @@ class Fill(PrimitiveWithInfer): ...@@ -723,25 +715,20 @@ class Fill(PrimitiveWithInfer):
"""init Fill""" """init Fill"""
def __infer__(self, dtype, dims, x): def __infer__(self, dtype, dims, x):
validator.check_const_input("type", dtype['value']) validator.check_value_type("shape", dims['value'], [tuple], self.name)
validator.check_const_input("shape", dims['value']) validator.check_value_type("value", x['value'], [numbers.Number, bool], self.name)
validator.check_const_input("value", x['value']) for idx, item in enumerate(dims['value']):
validator.check_type("shape", dims['value'], [tuple]) validator.check_integer("dims[%d]" % idx, item, 0, Rel.GT, self.name)
validator.check_type("value", x['value'], [numbers.Number, bool])
for item in dims['value']:
validator.check_type("item", item, [int])
validator.check_integer("item", item, 0, Rel.GT)
x_dtype = dtype['value']
valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64, valid_types = [mstype.bool_, mstype.int8, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64, mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64] mstype.float16, mstype.float32, mstype.float64]
validator.check_typename("value", x_dtype, valid_types) validator.check_type_same({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(x_dtype) x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.full(dims['value'], x['value'], x_nptype) ret = np.full(dims['value'], x['value'], x_nptype)
out = { out = {
'value': Tensor(ret), 'value': Tensor(ret),
'shape': dims['value'], 'shape': dims['value'],
'dtype': x_dtype, 'dtype': x['dtype'],
} }
return out return out
...@@ -772,8 +759,7 @@ class OnesLike(PrimitiveWithInfer): ...@@ -772,8 +759,7 @@ class OnesLike(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor) validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,))
return x_dtype return x_dtype
...@@ -804,8 +790,7 @@ class ZerosLike(PrimitiveWithInfer): ...@@ -804,8 +790,7 @@ class ZerosLike(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("x", x_dtype, mstype.tensor) validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name)
validator.check_typename('x_dtype', x_dtype, mstype.number_type + (mstype.bool_,))
return x_dtype return x_dtype
...@@ -830,14 +815,13 @@ class TupleToArray(PrimitiveWithInfer): ...@@ -830,14 +815,13 @@ class TupleToArray(PrimitiveWithInfer):
"""init TupleToArray""" """init TupleToArray"""
def infer_value(self, x): def infer_value(self, x):
validator.check_const_input("x", x) validator.check_value_type("x", x, [tuple], self.name)
validator.check_type("x", x, [tuple]) validator.check("size of x", len(x), '', 0, Rel.GT, self.name)
validator.check("size of x", len(x), '', 0, Rel.GT)
dtype = type(x[0]) dtype = type(x[0])
for i, item in enumerate(x): for i, item in enumerate(x):
validator.check_type(f"x[{i}]", item, [numbers.Number]) validator.check_value_type(f"x[{i}]", item, [numbers.Number], self.name)
if not all(isinstance(item, dtype) for item in x): if not all(isinstance(item, dtype) for item in x):
raise TypeError("All elements of input x must be have same type.") raise TypeError("For \'{self.name}\' all elements of input x must be have same type.")
if isinstance(x[0], int): if isinstance(x[0], int):
ret = np.array(x, np.int32) ret = np.array(x, np.int32)
else: else:
...@@ -867,8 +851,7 @@ class ScalarToArray(PrimitiveWithInfer): ...@@ -867,8 +851,7 @@ class ScalarToArray(PrimitiveWithInfer):
pass pass
def infer_value(self, x): def infer_value(self, x):
validator.check_const_input("x", x) validator.check_value_type("x", x, [int, float], self.name)
validator.check_type("x", x, [int, float])
if isinstance(x, int): if isinstance(x, int):
ret = np.array(x, np.int32) ret = np.array(x, np.int32)
else: else:
...@@ -899,9 +882,8 @@ class ScalarToTensor(PrimitiveWithInfer): ...@@ -899,9 +882,8 @@ class ScalarToTensor(PrimitiveWithInfer):
pass pass
def infer_value(self, x, dtype=mstype.float32): def infer_value(self, x, dtype=mstype.float32):
validator.check_const_input("x", x) validator.check_value_type("x", x, [int, float], self.name)
validator.check_type("x", x, [int, float]) validator.check_subclass("dtype", dtype, mstype.number, self.name)
validator.check_subclass("dtype", dtype, mstype.number, with_type_of=False)
data_type = mstype.dtype_to_nptype(dtype) data_type = mstype.dtype_to_nptype(dtype)
return Tensor(np.array(x, data_type)) return Tensor(np.array(x, data_type))
...@@ -943,15 +925,14 @@ class InvertPermutation(PrimitiveWithInfer): ...@@ -943,15 +925,14 @@ class InvertPermutation(PrimitiveWithInfer):
def __infer__(self, x): def __infer__(self, x):
x_shp = x['shape'] x_shp = x['shape']
x_value = x['value'] x_value = x['value']
validator.check_const_input("shape", x_shp) validator.check_value_type("shape", x_shp, [tuple], self.name)
validator.check_type("shape", x_shp, [tuple])
z = [x_value[i] for i in range(len(x_value))] z = [x_value[i] for i in range(len(x_value))]
z.sort() z.sort()
y = [None]*len(x_value) y = [None]*len(x_value)
for i, value in enumerate(x_value): for i, value in enumerate(x_value):
validator.check_type("input[%d]" % i, value, [int]) validator.check_value_type("input[%d]" % i, value, [int], self.name)
validator.check(f'value', z[i], f'index', i) validator.check(f'value', z[i], f'index', i, Rel.EQ, self.name)
y[value] = i y[value] = i
z.append(value) z.append(value)
return {'shape': x_shp, return {'shape': x_shp,
...@@ -986,8 +967,8 @@ class Argmax(PrimitiveWithInfer): ...@@ -986,8 +967,8 @@ class Argmax(PrimitiveWithInfer):
def __init__(self, axis=-1, output_type=mstype.int64): def __init__(self, axis=-1, output_type=mstype.int64):
"""init Argmax""" """init Argmax"""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_type("axis", axis, [int]) validator.check_value_type("axis", axis, [int], self.name)
validator.check_typename('output_type', output_type, [mstype.int32, mstype.int64]) validator.check_type_same({'output': output_type}, [mstype.int32, mstype.int64], self.name)
self.axis = axis self.axis = axis
self.add_prim_attr('output_type', output_type) self.add_prim_attr('output_type', output_type)
...@@ -996,14 +977,13 @@ class Argmax(PrimitiveWithInfer): ...@@ -996,14 +977,13 @@ class Argmax(PrimitiveWithInfer):
if axis is None: if axis is None:
axis = 0 axis = 0
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
axis = axis + x_rank if axis < 0 else axis axis = axis + x_rank if axis < 0 else axis
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
return ouput_shape return ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor) validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
validator.check_typename('input_x', x_dtype, [mstype.float32, mstype.float16])
return mstype.tensor_type(self.output_type) return mstype.tensor_type(self.output_type)
...@@ -1035,7 +1015,7 @@ class Argmin(PrimitiveWithInfer): ...@@ -1035,7 +1015,7 @@ class Argmin(PrimitiveWithInfer):
def __init__(self, axis=-1, output_type=mstype.int64): def __init__(self, axis=-1, output_type=mstype.int64):
"""init Argmin""" """init Argmin"""
self.init_prim_io_names(inputs=['x'], outputs=['output']) self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_type("axis", axis, [int]) validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis self.axis = axis
self.add_prim_attr('output_type', output_type) self.add_prim_attr('output_type', output_type)
...@@ -1044,13 +1024,13 @@ class Argmin(PrimitiveWithInfer): ...@@ -1044,13 +1024,13 @@ class Argmin(PrimitiveWithInfer):
if axis is None: if axis is None:
axis = 0 axis = 0
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
axis = axis + x_rank if axis < 0 else axis axis = axis + x_rank if axis < 0 else axis
ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis]
return ouput_shape return ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor) validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
return mstype.tensor_type(self.output_type) return mstype.tensor_type(self.output_type)
...@@ -1087,17 +1067,17 @@ class ArgMaxWithValue(PrimitiveWithInfer): ...@@ -1087,17 +1067,17 @@ class ArgMaxWithValue(PrimitiveWithInfer):
"""init ArgMaxWithValue""" """init ArgMaxWithValue"""
self.axis = axis self.axis = axis
self.keep_dims = keep_dims self.keep_dims = keep_dims
_check_infer_attr_reduce(axis, keep_dims) _check_infer_attr_reduce(axis, keep_dims, self.name)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor) validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
return mstype.tensor_type(mstype.int32), x_dtype return mstype.tensor_type(mstype.int32), x_dtype
...@@ -1133,17 +1113,17 @@ class ArgMinWithValue(PrimitiveWithInfer): ...@@ -1133,17 +1113,17 @@ class ArgMinWithValue(PrimitiveWithInfer):
"""init ArgMinWithValue""" """init ArgMinWithValue"""
self.axis = axis self.axis = axis
self.keep_dims = keep_dims self.keep_dims = keep_dims
_check_infer_attr_reduce(axis, keep_dims) _check_infer_attr_reduce(axis, keep_dims, self.name)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor) validator.check_subclass("input_x", x_dtype, mstype.tensor, self.name)
return mstype.tensor_type(mstype.int32), x_dtype return mstype.tensor_type(mstype.int32), x_dtype
...@@ -1183,13 +1163,11 @@ class Tile(PrimitiveWithInfer): ...@@ -1183,13 +1163,11 @@ class Tile(PrimitiveWithInfer):
def __infer__(self, x, multiples): def __infer__(self, x, multiples):
multiples_v = multiples['value'] multiples_v = multiples['value']
x_shp = x['shape'] x_shp = x['shape']
validator.check_const_input("shape", multiples_v) validator.check_value_type("shape", multiples_v, [tuple], self.name)
validator.check_type("shape", multiples_v, [tuple])
for i, multiple in enumerate(multiples_v): for i, multiple in enumerate(multiples_v):
validator.check_type("multiples[%d]" % i, multiple, [int]) validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
validator.check_typename('x', x['dtype'], valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32]
[mstype.int16, mstype.int32, mstype.bool_, validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
mstype.float16, mstype.float32])
len_sub = len(multiples_v) - len(x_shp) len_sub = len(multiples_v) - len(x_shp)
multiples_w = None multiples_w = None
if len_sub == 0: if len_sub == 0:
...@@ -1199,7 +1177,8 @@ class Tile(PrimitiveWithInfer): ...@@ -1199,7 +1177,8 @@ class Tile(PrimitiveWithInfer):
x_shp.insert(0, 1) x_shp.insert(0, 1)
multiples_w = multiples_v multiples_w = multiples_v
elif len_sub < 0: elif len_sub < 0:
raise ValueError("The length of multiples can not be smaller than the length of dimension in input_x.") raise ValueError(f'For \'{self.name}\' the length of multiples can not be smaller than '
f'the length of dimension in input_x.')
for i, a in enumerate(multiples_w): for i, a in enumerate(multiples_w):
x_shp[i] *= a x_shp[i] *= a
value = None value = None
...@@ -1246,23 +1225,23 @@ class UnsortedSegmentSum(PrimitiveWithInfer): ...@@ -1246,23 +1225,23 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
def __infer__(self, x, segment_ids, num_segments): def __infer__(self, x, segment_ids, num_segments):
x_type = x['dtype'] x_type = x['dtype']
x_shp = x['shape'] x_shp = x['shape']
validator.check_subclass("input_x", x_type, mstype.tensor) validator.check_subclass("input_x", x_type, mstype.tensor, self.name)
validator.check_type("x_shape", x_shp, [list]) validator.check_value_type("x_shape", x_shp, [list], self.name)
x_shp_len = len(x_shp) x_shp_len = len(x_shp)
validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT) validator.check_integer("rank of input_x", x_shp_len, 0, Rel.GT, self.name)
segment_ids_shp = segment_ids['shape'] segment_ids_shp = segment_ids['shape']
segment_ids_type = segment_ids['dtype'] segment_ids_type = segment_ids['dtype']
validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor) validator.check_subclass("segment_ids", segment_ids_type, mstype.tensor, self.name)
validator.check_type("segment_ids", segment_ids_shp, [list]) validator.check_value_type("segment_ids", segment_ids_shp, [list], self.name)
segment_ids_shp_len = len(segment_ids_shp) segment_ids_shp_len = len(segment_ids_shp)
validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT) validator.check_integer("rank of segment_ids", segment_ids_shp_len, 0, Rel.GT, self.name)
validator.check(f'rank of input_x', len(x_shp), validator.check(f'rank of input_x', len(x_shp),
'rank of segments_id', len(segment_ids_shp), Rel.GE) 'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
for i, value in enumerate(segment_ids_shp): for i, value in enumerate(segment_ids_shp):
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i]) validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
num_segments_v = num_segments['value'] num_segments_v = num_segments['value']
validator.check_type('num_segments', num_segments_v, [int]) validator.check_value_type('num_segments', num_segments_v, [int], self.name)
validator.check_integer("num_segments", num_segments_v, 0, Rel.GT) validator.check_integer("num_segments", num_segments_v, 0, Rel.GT, self.name)
shp = [num_segments_v] shp = [num_segments_v]
shp += x_shp[segment_ids_shp_len:] shp += x_shp[segment_ids_shp_len:]
out = {'shape': shp, out = {'shape': shp,
...@@ -1306,7 +1285,7 @@ class Concat(PrimitiveWithInfer): ...@@ -1306,7 +1285,7 @@ class Concat(PrimitiveWithInfer):
def __init__(self, axis=0): def __init__(self, axis=0):
"""init Tile""" """init Tile"""
self.__setattr_flag__ = True self.__setattr_flag__ = True
validator.check_type("axis", axis, [int]) validator.check_value_type("axis", axis, [int], self.name)
def __infer__(self, input_x): def __infer__(self, input_x):
axis = self.axis axis = self.axis
...@@ -1323,25 +1302,25 @@ class Concat(PrimitiveWithInfer): ...@@ -1323,25 +1302,25 @@ class Concat(PrimitiveWithInfer):
return out return out
def _get_pack_shape(x_shape, x_type, axis): def _get_pack_shape(x_shape, x_type, axis, prim_name):
"""for pack output shape""" """for pack output shape"""
validator.check_type("shape", x_shape, [tuple, list]) validator.check_value_type("shape", x_shape, [tuple, list], prim_name)
validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT, prim_name)
validator.check_subclass("shape0", x_type[0], mstype.tensor) validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shape[0]) rank_base = len(x_shape[0])
N = len(x_shape) N = len(x_shape)
out_shape = x_shape[0] out_shape = x_shape[0]
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
if axis < 0: if axis < 0:
axis = axis + rank_base + 1 axis = axis + rank_base + 1
for i in range(1, N): for i in range(1, N):
v = x_shape[i] v = x_shape[i]
validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base) validator.check('len of x_shape[%d]' % i, len(v), 'len of rank_base', rank_base, Rel.EQ, prim_name)
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0], Rel.EQ, prim_name)
for j in range(rank_base): for j in range(rank_base):
if v[j] != x_shape[0][j]: if v[j] != x_shape[0][j]:
raise ValueError("Pack evaluator element %d shape in input can not pack with first element" % i) raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not pack with first element")
out_shape.insert(axis, N) out_shape.insert(axis, N)
return out_shape return out_shape
...@@ -1376,14 +1355,14 @@ class Pack(PrimitiveWithInfer): ...@@ -1376,14 +1355,14 @@ class Pack(PrimitiveWithInfer):
def __init__(self, axis=0): def __init__(self, axis=0):
"""init Pack""" """init Pack"""
self.__setattr_flag__ = True self.__setattr_flag__ = True
validator.check_type("axis", axis, [int]) validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis self.axis = axis
def __infer__(self, value): def __infer__(self, value):
x_shape = value['shape'] x_shape = value['shape']
x_type = value['dtype'] x_type = value['dtype']
self.add_prim_attr('num', len(x_shape)) self.add_prim_attr('num', len(x_shape))
all_shape = _get_pack_shape(x_shape, x_type, self.axis) all_shape = _get_pack_shape(x_shape, x_type, self.axis, self.name)
out = {'shape': all_shape, out = {'shape': all_shape,
'dtype': x_type[0], 'dtype': x_type[0],
'value': None} 'value': None}
...@@ -1429,22 +1408,23 @@ class Unpack(PrimitiveWithInfer): ...@@ -1429,22 +1408,23 @@ class Unpack(PrimitiveWithInfer):
def __init__(self, axis=0): def __init__(self, axis=0):
"""init Unpack""" """init Unpack"""
self.__setattr_flag__ = True self.__setattr_flag__ = True
validator.check_type("axis", axis, [int]) validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis self.axis = axis
def __infer__(self, x): def __infer__(self, x):
validator.check_subclass("x", x['dtype'], mstype.tensor) validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
x_shape = list(x['shape']) x_shape = list(x['shape'])
dim = len(x_shape) dim = len(x_shape)
validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name)
if self.axis < 0: if self.axis < 0:
self.axis = self.axis + dim self.axis = self.axis + dim
output_num = x_shape[self.axis] output_num = x_shape[self.axis]
validator.check_type("num", output_num, [int]) validator.check_value_type("num", output_num, [int], self.name)
validator.check_integer("output_num", output_num, 0, Rel.GT) validator.check_integer("output_num", output_num, 0, Rel.GT, self.name)
self.add_prim_attr('num', output_num) self.add_prim_attr('num', output_num)
output_valid_check = x_shape[self.axis] - output_num output_valid_check = x_shape[self.axis] - output_num
validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ) validator.check_integer("The dimension which to unpack divides output_num", output_valid_check, 0, Rel.EQ,
self.name)
out_shapes = [] out_shapes = []
out_dtypes = [] out_dtypes = []
out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:] out_shape = x_shape[:self.axis] + x_shape[self.axis + 1:]
...@@ -1486,8 +1466,8 @@ class Slice(PrimitiveWithInfer): ...@@ -1486,8 +1466,8 @@ class Slice(PrimitiveWithInfer):
def __infer__(self, x, begin, size): def __infer__(self, x, begin, size):
x_shape = x['shape'] x_shape = x['shape']
x_shp_len = len(x_shape) x_shp_len = len(x_shape)
validator.check_const_input('begin', begin['value']) validator.check_const_input('begin', begin['value'], self.name)
validator.check_const_input('size', size['value']) validator.check_const_input('size', size['value'], self.name)
begin_v, size_v = begin['value'], size['value'] begin_v, size_v = begin['value'], size['value']
if begin_v is None or size_v is None: if begin_v is None or size_v is None:
return {'shape': None, return {'shape': None,
...@@ -1499,7 +1479,8 @@ class Slice(PrimitiveWithInfer): ...@@ -1499,7 +1479,8 @@ class Slice(PrimitiveWithInfer):
for i in range(x_shp_len): for i in range(x_shp_len):
if x_shape[i] < begin_v[i] + size_v[i]: if x_shape[i] < begin_v[i] + size_v[i]:
y = begin_v[i] + size_v[i] y = begin_v[i] + size_v[i]
raise ValueError("Slice shape can not bigger than orign shape %d, %d." % (x_shape[i], y)) raise ValueError("For '%s' slice shape can not bigger than orign shape %d, %d." %
(self.name, x_shape[i], y))
return {'shape': size_v, return {'shape': size_v,
'dtype': x['dtype'], 'dtype': x['dtype'],
'value': None} 'value': None}
...@@ -1565,11 +1546,11 @@ class Select(PrimitiveWithInfer): ...@@ -1565,11 +1546,11 @@ class Select(PrimitiveWithInfer):
def infer_dtype(self, cond_type, x_type, y_type): def infer_dtype(self, cond_type, x_type, y_type):
self.add_prim_attr('T', x_type) self.add_prim_attr('T', x_type)
validator.check_subclass("x_type", x_type, mstype.tensor) validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_subclass("y_type", y_type, mstype.tensor) validator.check_subclass("y_type", y_type, mstype.tensor, self.name)
validator.check_typename("cond_type", cond_type, [mstype.bool_]) validator.check_tensor_type_same({"cond": cond_type}, [mstype.bool_], self.name)
if x_type != y_type: if x_type != y_type:
raise TypeError('The x_type %s must be the same as y_type %s.' % (x_type, y_type)) raise TypeError('\'%s\' the x_type %s must be the same as y_type %s.' % (self.name, x_type, y_type))
return x_type return x_type
...@@ -1637,26 +1618,23 @@ class StridedSlice(PrimitiveWithInfer): ...@@ -1637,26 +1618,23 @@ class StridedSlice(PrimitiveWithInfer):
shrink_axis_mask=0): shrink_axis_mask=0):
"""init StrideSlice""" """init StrideSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
validator.check_type('begin_mask', begin_mask, [int]) validator.check_value_type('begin_mask', begin_mask, [int], self.name)
validator.check_type('end_mask', end_mask, [int]) validator.check_value_type('end_mask', end_mask, [int], self.name)
validator.check_type('ellipsis_mask', ellipsis_mask, [int]) validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
validator.check_type('new_axis_mask', new_axis_mask, [int]) validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int]) validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
def __infer__(self, x, begin, end, strides): def __infer__(self, x, begin, end, strides):
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
validator.check_const_input("begin", begin_v) validator.check_value_type("begin", begin_v, [tuple], self.name)
validator.check_const_input("end", end_v) validator.check_value_type("end", end_v, [tuple], self.name)
validator.check_const_input("strides", strides_v) validator.check_value_type("strides", strides_v, [tuple], self.name)
validator.check_type("begin", begin_v, [tuple])
validator.check_type("end", end_v, [tuple])
validator.check_type("strides", strides_v, [tuple])
x_shape = x['shape'] x_shape = x['shape']
x_shp_len = len(x_shape) x_shp_len = len(x_shape)
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len: if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} " raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and "
f"must be equal to the dims({x_shp_len}) of input.") f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.")
ret_shape = [] ret_shape = []
append_dimensions = [] append_dimensions = []
...@@ -1669,8 +1647,8 @@ class StridedSlice(PrimitiveWithInfer): ...@@ -1669,8 +1647,8 @@ class StridedSlice(PrimitiveWithInfer):
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)]) append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
continue continue
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1': if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE) validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT) validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
continue continue
begin_idx = begin_v[i] begin_idx = begin_v[i]
...@@ -1680,9 +1658,9 @@ class StridedSlice(PrimitiveWithInfer): ...@@ -1680,9 +1658,9 @@ class StridedSlice(PrimitiveWithInfer):
begin_idx = 0 begin_idx = 0
if self.end_mask: if self.end_mask:
end_idx = x_shape[i] end_idx = x_shape[i]
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE) validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE) validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE) validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
if strides_idx > 0: if strides_idx > 0:
# If sliced forward , end_idx >= begin_idx # If sliced forward , end_idx >= begin_idx
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE) validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
...@@ -1736,7 +1714,7 @@ class Diag(PrimitiveWithInfer): ...@@ -1736,7 +1714,7 @@ class Diag(PrimitiveWithInfer):
"""init Diag""" """init Diag"""
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass('input_x', x_type, mstype.tensor) validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
return x_type return x_type
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
...@@ -1748,7 +1726,7 @@ class Diag(PrimitiveWithInfer): ...@@ -1748,7 +1726,7 @@ class Diag(PrimitiveWithInfer):
def infer_value(self, x): def infer_value(self, x):
if x is None: if x is None:
return None return None
validator.check("input x rank", len(x.shape()), "", 1) validator.check_integer("input x rank", len(x.shape()), 1, Rel.EQ, self.name)
ret = np.diag(x.asnumpy()) ret = np.diag(x.asnumpy())
return Tensor(ret) return Tensor(ret)
...@@ -1783,13 +1761,13 @@ class DiagPart(PrimitiveWithInfer): ...@@ -1783,13 +1761,13 @@ class DiagPart(PrimitiveWithInfer):
"""init DiagPart""" """init DiagPart"""
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass('input_x', x_type, mstype.tensor) validator.check_subclass('input_x', x_type, mstype.tensor, self.name)
return x_type return x_type
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
if len(x_shape)%2 != 0 or \ if len(x_shape)%2 != 0 or \
not x_shape: not x_shape:
raise ValueError(f"DiagPart input rank must be non-zero and even, but got rank {len(x_shape)}, " raise ValueError(f"For \'{self.name}\' input rank must be non-zero and even, but got rank {len(x_shape)}, "
f"with shapes {x_shape}") f"with shapes {x_shape}")
length = len(x_shape) // 2 length = len(x_shape) // 2
ret_shape = x_shape[0:length] ret_shape = x_shape[0:length]
...@@ -1798,7 +1776,7 @@ class DiagPart(PrimitiveWithInfer): ...@@ -1798,7 +1776,7 @@ class DiagPart(PrimitiveWithInfer):
def infer_value(self, x): def infer_value(self, x):
if x is None: if x is None:
return None return None
validator.check("x rank", len(x.shape()), "", 2) validator.check("x rank", len(x.shape()), "", 2, Rel.EQ, self.name)
ret = np.diag(x.asnumpy()) ret = np.diag(x.asnumpy())
return Tensor(ret) return Tensor(ret)
...@@ -1826,12 +1804,10 @@ class Eye(PrimitiveWithInfer): ...@@ -1826,12 +1804,10 @@ class Eye(PrimitiveWithInfer):
"""init Eye""" """init Eye"""
def infer_value(self, n, m, t): def infer_value(self, n, m, t):
validator.check_type("n", n, [int]) validator.check_integer("n", n, 0, Rel.GT, self.name)
validator.check_integer("n", n, 0, Rel.GT) validator.check_integer("m", m, 0, Rel.GT, self.name)
validator.check_type("m", m, [int])
validator.check_integer("m", m, 0, Rel.GT)
args = {"dtype": t} args = {"dtype": t}
validator.check_type_same(args, mstype.number_type + (mstype.bool_,)) validator.check_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
np_type = mstype.dtype_to_nptype(t) np_type = mstype.dtype_to_nptype(t)
ret = np.eye(n, m, dtype=np_type) ret = np.eye(n, m, dtype=np_type)
return Tensor(ret) return Tensor(ret)
...@@ -1866,16 +1842,15 @@ class ScatterNd(PrimitiveWithInfer): ...@@ -1866,16 +1842,15 @@ class ScatterNd(PrimitiveWithInfer):
def __infer__(self, indices, update, shape): def __infer__(self, indices, update, shape):
shp = shape['value'] shp = shape['value']
validator.check_subclass("indices_dtype", indices['dtype'], mstype.tensor) validator.check_subclass("update_dtype", update['dtype'], mstype.tensor, self.name)
validator.check_subclass("update_dtype", update['dtype'], mstype.tensor) validator.check_tensor_type_same({"indices": indices['dtype']}, mstype.int_type, self.name)
validator.check_typename("indices_dtype", indices['dtype'], mstype.int_type) validator.check_value_type("shape", shp, [tuple], self.name)
validator.check_type("shape", shp, [tuple])
for i, x in enumerate(shp): for i, x in enumerate(shp):
validator.check_integer("shape[%d]" % i, x, 0, Rel.GT) validator.check_integer("shape[%d]" % i, x, 0, Rel.GT, self.name)
indices_shape, update_shape = indices["shape"], update["shape"] indices_shape, update_shape = indices["shape"], update["shape"]
if indices_shape[0] != update_shape[0]: if indices_shape[0] != update_shape[0]:
raise ValueError('The indices_shape[0] and update_shape[0] must be equal.') raise ValueError(f'For \'{self.name}\' The indices_shape[0] and update_shape[0] must be equal.')
return {'shape': shp, return {'shape': shp,
'dtype': update['dtype'], 'dtype': update['dtype'],
...@@ -1913,7 +1888,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): ...@@ -1913,7 +1888,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) self.init_prim_io_names(inputs=['image_in'], outputs=['image_out'])
def infer_shape(self, x): def infer_shape(self, x):
validator.check('the dimension of input_x', len(x), '', 2, Rel.GE) validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name)
return tuple(x)[:-2] + tuple(self.size) return tuple(x)[:-2] + tuple(self.size)
def infer_dtype(self, x): def infer_dtype(self, x):
...@@ -1947,13 +1922,12 @@ class GatherNd(PrimitiveWithInfer): ...@@ -1947,13 +1922,12 @@ class GatherNd(PrimitiveWithInfer):
def infer_shape(self, x_shape, indices_shape): def infer_shape(self, x_shape, indices_shape):
validator.check('the dimension of x', len(x_shape), validator.check('the dimension of x', len(x_shape),
'the dimension of indices', indices_shape[-1], Rel.GE) 'the dimension of indices', indices_shape[-1], Rel.GE, self.name)
return indices_shape[:-1] + x_shape[indices_shape[-1]:] return indices_shape[:-1] + x_shape[indices_shape[-1]:]
def infer_dtype(self, x_dtype, indices_dtype): def infer_dtype(self, x_dtype, indices_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor) validator.check_tensor_type_same({"indices": indices_dtype}, mstype.int_type, self.name)
validator.check_typename("indices_dtype", indices_dtype, mstype.int_type)
return x_dtype return x_dtype
...@@ -1995,12 +1969,9 @@ class ScatterNdUpdate(PrimitiveWithInfer): ...@@ -1995,12 +1969,9 @@ class ScatterNdUpdate(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype, indices_dtype, value_dtype): def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name)
validator.check_subclass("indices_dtype", indices_dtype, mstype.tensor) args = {"x": x_dtype, "value": value_dtype}
validator.check_subclass("value_dtype", value_dtype, mstype.tensor) validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_typename('indices_dtype', indices_dtype, mstype.int_type)
args = {"x_dtype": x_dtype, "value_dtype": value_dtype}
validator.check_type_same(args, (mstype.bool_,) + mstype.number_type)
return x_dtype return x_dtype
...@@ -2038,7 +2009,7 @@ class SpaceToDepth(PrimitiveWithInfer): ...@@ -2038,7 +2009,7 @@ class SpaceToDepth(PrimitiveWithInfer):
def __init__(self, block_size): def __init__(self, block_size):
"""Init SpaceToDepth""" """Init SpaceToDepth"""
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
validator.check_type('block_size', block_size, [int]) validator.check_value_type('block_size', block_size, [int], self.name)
validator.check('block_size', block_size, '', 2, Rel.GE) validator.check('block_size', block_size, '', 2, Rel.GE)
self.block_size = block_size self.block_size = block_size
self.add_prim_attr("data_format", "NCHW") self.add_prim_attr("data_format", "NCHW")
...@@ -2048,7 +2019,7 @@ class SpaceToDepth(PrimitiveWithInfer): ...@@ -2048,7 +2019,7 @@ class SpaceToDepth(PrimitiveWithInfer):
out_shape = copy.deepcopy(x_shape) out_shape = copy.deepcopy(x_shape)
for i in range(2): for i in range(2):
if out_shape[i+2] % self.block_size != 0: if out_shape[i+2] % self.block_size != 0:
raise ValueError(f'SpaceToDepth input shape[{i+2}] {out_shape[i+2]} should be ' raise ValueError(f'For \'{self.name}\' input shape[{i+2}] {out_shape[i+2]} should be '
f'fully divided by block_size {self.block_size}') f'fully divided by block_size {self.block_size}')
out_shape[i+2] //= self.block_size out_shape[i+2] //= self.block_size
...@@ -2056,7 +2027,7 @@ class SpaceToDepth(PrimitiveWithInfer): ...@@ -2056,7 +2027,7 @@ class SpaceToDepth(PrimitiveWithInfer):
return out_shape return out_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
return x_dtype return x_dtype
...@@ -2096,8 +2067,8 @@ class DepthToSpace(PrimitiveWithInfer): ...@@ -2096,8 +2067,8 @@ class DepthToSpace(PrimitiveWithInfer):
def __init__(self, block_size): def __init__(self, block_size):
"""Init DepthToSpace""" """Init DepthToSpace"""
self.init_prim_io_names(inputs=['x'], outputs=['y']) self.init_prim_io_names(inputs=['x'], outputs=['y'])
validator.check_type('block_size', block_size, [int]) validator.check_value_type('block_size', block_size, [int], self.name)
validator.check('block_size', block_size, '', 2, Rel.GE) validator.check('block_size', block_size, '', 2, Rel.GE, self.name)
self.block_size = block_size self.block_size = block_size
self.add_prim_attr("data_format", "NCHW") self.add_prim_attr("data_format", "NCHW")
...@@ -2107,12 +2078,13 @@ class DepthToSpace(PrimitiveWithInfer): ...@@ -2107,12 +2078,13 @@ class DepthToSpace(PrimitiveWithInfer):
for i in range(2): for i in range(2):
out_shape[i+2] *= self.block_size out_shape[i+2] *= self.block_size
validator.check('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size), '', 0) validator.check_integer('x_shape[1] % (block_size*block_size)', x_shape[1] % (self.block_size*self.block_size),
0, Rel.EQ, self.name)
out_shape[1] //= self.block_size * self.block_size out_shape[1] //= self.block_size * self.block_size
return out_shape return out_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("x_dtype", x_dtype, mstype.tensor) validator.check_subclass("x_dtype", x_dtype, mstype.tensor, self.name)
return x_dtype return x_dtype
...@@ -2159,27 +2131,26 @@ class SpaceToBatch(PrimitiveWithInfer): ...@@ -2159,27 +2131,26 @@ class SpaceToBatch(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, block_size, paddings): def __init__(self, block_size, paddings):
"""Init SpaceToBatch""" """Init SpaceToBatch"""
validator.check_type('block_size', block_size, [int]) validator.check_value_type('block_size', block_size, [int], self.name)
validator.check('block_size', block_size, '', 1, Rel.GT) validator.check('block_size', block_size, '', 1, Rel.GT, self.name)
self.block_size = block_size self.block_size = block_size
validator.check('paddings shape', np.array(paddings).shape, '', (2, 2)) validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name)
for elem in itertools.chain(*paddings): for elem in itertools.chain(*paddings):
validator.check_type('paddings element', elem, [int]) validator.check_value_type('paddings element', elem, [int], self.name)
self.paddings = paddings self.paddings = paddings
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor) validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_typename('input_x', x_dtype, mstype.number_type)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check('rank of input_x', len(x_shape), '', 4) validator.check_integer('rank of input_x', len(x_shape), 4, Rel.EQ, self.name)
out_shape = copy.deepcopy(x_shape) out_shape = copy.deepcopy(x_shape)
for i in range(2): for i in range(2):
padded = out_shape[i+2] + self.paddings[i][0] + \ padded = out_shape[i+2] + self.paddings[i][0] + \
self.paddings[i][1] self.paddings[i][1]
if padded % self.block_size != 0: if padded % self.block_size != 0:
raise ValueError(f'padded[{i}] {padded} should be divisible by ' raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
f'block_size {self.block_size}') f'block_size {self.block_size}')
out_shape[i+2] = padded // self.block_size out_shape[i+2] = padded // self.block_size
out_shape[0] *= self.block_size * self.block_size out_shape[0] *= self.block_size * self.block_size
...@@ -2227,17 +2198,16 @@ class BatchToSpace(PrimitiveWithInfer): ...@@ -2227,17 +2198,16 @@ class BatchToSpace(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, block_size, crops): def __init__(self, block_size, crops):
"""Init BatchToSpace""" """Init BatchToSpace"""
validator.check_type('block_size', block_size, [int]) validator.check_value_type('block_size', block_size, [int], self.name)
validator.check('block_size', block_size, '', 1, Rel.GT) validator.check('block_size', block_size, '', 1, Rel.GT, self.name)
self.block_size = block_size self.block_size = block_size
validator.check('crops shape', np.array(crops).shape, '', (2, 2)) validator.check('crops shape', np.array(crops).shape, '', (2, 2))
for elem in itertools.chain(*crops): for elem in itertools.chain(*crops):
validator.check_type('crops element', elem, [int]) validator.check_value_type('crops element', elem, [int], self.name)
self.crops = crops self.crops = crops
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_subclass("input_x", x_dtype, mstype.tensor) validator.check_tensor_type_same({'input_x': x_dtype}, mstype.number_type, self.name)
validator.check_typename('input_x', x_dtype, mstype.number_type)
return x_dtype return x_dtype
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
...@@ -2246,11 +2216,11 @@ class BatchToSpace(PrimitiveWithInfer): ...@@ -2246,11 +2216,11 @@ class BatchToSpace(PrimitiveWithInfer):
for i in range(2): for i in range(2):
x_block_prod = out_shape[i+2] * self.block_size x_block_prod = out_shape[i+2] * self.block_size
crops_sum = self.crops[i][0] + self.crops[i][1] crops_sum = self.crops[i][0] + self.crops[i][1]
validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT) validator.check("x block shape prod", x_block_prod, 'crops sum', crops_sum, Rel.GT, self.name)
out_shape[i+2] = x_block_prod - crops_sum out_shape[i+2] = x_block_prod - crops_sum
block_size_prod = self.block_size * self.block_size block_size_prod = self.block_size * self.block_size
if out_shape[0] % block_size_prod != 0: if out_shape[0] % block_size_prod != 0:
raise ValueError(f'input_x dimension 0 {out_shape[0]} should be divisible by ' raise ValueError(f'For \'{self.name}\' input_x dimension 0 {out_shape[0]} should be divisible by '
f'block_size_prod {block_size_prod}') f'block_size_prod {block_size_prod}')
out_shape[0] = out_shape[0] // block_size_prod out_shape[0] = out_shape[0] // block_size_prod
return out_shape return out_shape
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import functools
import numpy as np
from mindspore import ops
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.composite as C
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine
from mindspore.common.api import _executor
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward\
import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class ExpandDimsNet(nn.Cell):
def __init__(self, axis):
super(ExpandDimsNet, self).__init__()
self.axis = axis
self.op = P.ExpandDims()
def construct(self, x):
return self.op(x, self.axis)
class IsInstanceNet(nn.Cell):
def __init__(self, inst):
super(IsInstanceNet, self).__init__()
self.inst = inst
self.op = P.IsInstance()
def construct(self, t):
return self.op(self.inst, t)
class ReshapeNet(nn.Cell):
def __init__(self, shape):
super(ReshapeNet, self).__init__()
self.shape = shape
self.op = P.Reshape()
def construct(self, x):
return self.op(x, self.shape)
raise_set = [
# input is scala, not Tensor
('ExpandDims0', {
'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [5.0, 1],
'skip': ['backward']}),
# axis is as a parameter
('ExpandDims1', {
'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 1],
'skip': ['backward']}),
# axis as an attribute, but less then lower limit
('ExpandDims2', {
'block': (ExpandDimsNet(-4), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# axis as an attribute, but greater then upper limit
('ExpandDims3', {
'block': (ExpandDimsNet(3), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input is scala, not Tensor
('DType0', {
'block': (P.DType(), {'exception': TypeError, 'error_keywords': ['DType']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input x scala, not Tensor
('SameTypeShape0', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input y scala, not Tensor
('SameTypeShape1', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
'skip': ['backward']}),
# type of x and y not match
('SameTypeShape2', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}),
# shape of x and y not match
('SameTypeShape3', {
'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
'skip': ['backward']}),
# sub_type is None
('IsSubClass0', {
'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'desc_inputs': [None, mstype.number],
'skip': ['backward']}),
# type_ is None
('IsSubClass1', {
'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'desc_inputs': [mstype.number, None],
'skip': ['backward']}),
# inst is var
('IsInstance0', {
'block': (P.IsInstance(), {'exception': ValueError, 'error_keywords': ['IsInstance']}),
'desc_inputs': [5.0, mstype.number],
'skip': ['backward']}),
# t is not mstype.Type
('IsInstance1', {
'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
'desc_inputs': [None],
'skip': ['backward']}),
# input x is scalar, not Tensor
('Reshape0', {
'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
'desc_inputs': [5.0, (1, 2)],
'skip': ['backward']}),
# input shape is var
('Reshape1', {
'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), (2, 3, 2)],
'skip': ['backward']}),
# element of shape is not int
('Reshape3', {
'block': (ReshapeNet((2, 3.0, 2)), {'exception': TypeError, 'error_keywords': ['Reshape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
]
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
return raise_set
...@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg(): ...@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
net = NetWork() net = NetWork()
with pytest.raises(ValueError) as ex: with pytest.raises(ValueError) as ex:
net(input_tensor) net(input_tensor)
assert "The `begin[0]` should be an int and must greater or equal to -6, but got -7" in str(ex.value) assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(ex.value)
def test_tensor_slice_reduce_out_of_bounds_positive(): def test_tensor_slice_reduce_out_of_bounds_positive():
...@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive(): ...@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
net = NetWork() net = NetWork()
with pytest.raises(ValueError) as ex: with pytest.raises(ValueError) as ex:
net(input_tensor) net(input_tensor)
assert "The `begin[0]` should be an int and must less than 6, but got 6" in str(ex.value) assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import numpy as np import numpy as np
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from mindspore._checkparam import ParamValidator as validator from mindspore._checkparam import Validator as validator
def avg_pooling(x, pool_h, pool_w, stride): def avg_pooling(x, pool_h, pool_w, stride):
...@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride): ...@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
Returns: Returns:
numpy.ndarray, an output array after applying average pooling on input array. numpy.ndarray, an output array after applying average pooling on input array.
""" """
validator.check_integer("stride", stride, 0, Rel.GT) validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1 out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1 out_w = (width - pool_w)//stride + 1
...@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0, ...@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation=1, groups=1, padding_mode='zeros'): dilation=1, groups=1, padding_mode='zeros'):
"""Convolution 2D.""" """Convolution 2D."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
validator.check_type('stride', stride, (int, tuple)) validator.check_value_type('stride', stride, (int, tuple), None)
if isinstance(stride, int): if isinstance(stride, int):
stride = (stride, stride) stride = (stride, stride)
elif len(stride) == 4: elif len(stride) == 4:
...@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0, ...@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
f"a tuple of two positive int numbers, but got {stride}") f"a tuple of two positive int numbers, but got {stride}")
stride_h = stride[0] stride_h = stride[0]
stride_w = stride[1] stride_w = stride[1]
validator.check_type('dilation', dilation, (int, tuple)) validator.check_value_type('dilation', dilation, (int, tuple), None)
if isinstance(dilation, int): if isinstance(dilation, int):
dilation = (dilation, dilation) dilation = (dilation, dilation)
elif len(dilation) == 4: elif len(dilation) == 4:
...@@ -384,7 +384,7 @@ def matmul(x, w, b=None): ...@@ -384,7 +384,7 @@ def matmul(x, w, b=None):
def max_pooling(x, pool_h, pool_w, stride): def max_pooling(x, pool_h, pool_w, stride):
"""Max pooling.""" """Max pooling."""
validator.check_integer("stride", stride, 0, Rel.GT) validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1 out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1 out_w = (width - pool_w)//stride + 1
...@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride): ...@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
def max_pool_with_argmax(x, pool_h, pool_w, stride): def max_pool_with_argmax(x, pool_h, pool_w, stride):
"""Max pooling with argmax.""" """Max pooling with argmax."""
validator.check_integer("stride", stride, 0, Rel.GT) validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1 out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1 out_w = (width - pool_w)//stride + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册