Commit 20782294 authored by fary86

Add prim name to error message for array_ops

Parent 789edcb2
@@ -210,7 +210,7 @@ class Validator:
type_names = []
for t in valid_values:
type_names.append(str(t))
types_info = '[' + ", ".join(type_names) + ']'
types_info = '[' + ', '.join(type_names) + ']'
raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},'
f' but got {elem_type}.')
return (arg_key, elem_type)
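This hunk shows the commit's theme: validator errors now carry the primitive's name (prim_name) so the failing operator is identifiable from the message. A minimal sketch of the pattern, with an illustrative function name rather than the exact MindSpore code:

def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    """Sketch: raise a TypeError that names the offending primitive."""
    if not isinstance(arg_value, tuple(valid_types)):
        type_names = [t.__name__ for t in valid_types]
        types_info = '[' + ', '.join(type_names) + ']'
        # With a prim name the message reads: "For 'AvgPool1d' type of `stride` ..."
        prefix = f"For '{prim_name}'" if prim_name else 'The'
        raise TypeError(f"{prefix} type of `{arg_name}` should be in {types_info},"
                        f" but got {type(arg_value).__name__}.")
    return arg_value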
@@ -320,224 +320,6 @@ class Validator:
raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
class ParamValidator:
"""Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@staticmethod
def equal(arg_name, arg_value, cond_str, cond):
"""Judging valid value."""
if not cond:
raise ValueError(f'The `{arg_name}` must be {cond_str}, but got {arg_value}.')
@staticmethod
def check(arg_name, arg_value, value_name, value, rel=Rel.EQ):
"""This method is only used for check int values, since when compare float values,
we need consider float error."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
raise ValueError(f'The `{arg_name}` should be {rel_str}, but got {arg_value}.')
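Rel.get_fns and Rel.get_strs, used throughout these checks, map a relation constant to a comparison function and a message template. The real definition is not part of this diff; a hypothetical sketch consistent with the call sites:

class Rel:
    """Hypothetical sketch of the relation helper used above, not the real definition."""
    EQ, GE, GT, LT, INC_BOTH = 'eq', 'ge', 'gt', 'lt', 'inc_both'
    _fns = {
        'eq': lambda x, y: x == y,
        'ge': lambda x, y: x >= y,
        'gt': lambda x, y: x > y,
        'lt': lambda x, y: x < y,
        'inc_both': lambda x, lo, hi: lo <= x <= hi,  # inclusive range
    }
    _strs = {
        'eq': 'equal to {}',
        'ge': 'greater or equal to {}',
        'gt': 'greater than {}',
        'lt': 'less than {}',
        'inc_both': '[{}, {}]',
    }

    @staticmethod
    def get_fns(rel):
        """Return the comparison function for a relation constant."""
        return Rel._fns[rel]

    @staticmethod
    def get_strs(rel):
        """Return the message template for a relation constant."""
        return Rel._strs[rel]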
@staticmethod
def check_integer(arg_name, arg_value, value, rel):
"""Integer value judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_shape_length(arg_name, arg_value, value, rel):
"""Shape length judgment."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, value):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The length of `{arg_name}` should be an int and must {rel_str}, but got {arg_value}')
return arg_value
@staticmethod
def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""This method is only used for check int values,
since when compare float values, we need consider float error."""
rel_fn = Rel.get_fns(rel)
type_mismatch = not isinstance(arg_value, int)
if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be an int in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel):
"""Is it necessary to consider error when comparing float values."""
rel_fn = Rel.get_fns(rel)
if not rel_fn(arg_value, lower_limit, upper_limit):
rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
raise ValueError(f'The `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
type_str = (type(type_).__name__ if isinstance(type_, (tuple, list)) else "") + str(type_)
raise TypeError(f'The {"type of" if with_type_of else ""} `{arg_name}` should be subclass'
f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
@staticmethod
def check_args_tensor(args):
"""Check whether args are all tensor."""
if not isinstance(args, dict):
raise TypeError("The args should be a dict.")
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
# Notice: bool is a subclass of int, so `check_type('x', True, [int])` will fail the check,
# while `check_type('x', True, [bool, int])` will pass.
if isinstance(arg_value, bool) and bool not in tuple(valid_types):
raise_error_msg()
if isinstance(arg_value, tuple(valid_types)):
return arg_value
raise_error_msg()
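The comment above points at a genuine Python pitfall: bool is a subclass of int, so isinstance alone accepts True where an int is expected. A quick runnable demonstration of the guard pattern used by check_integer and check_type:

# bool is a subclass of int, so isinstance alone cannot reject it
assert isinstance(True, int)
assert issubclass(bool, int)

def wants_int(x):
    # the guard pattern used by check_integer/check_type above
    if not isinstance(x, int) or isinstance(x, bool):
        raise TypeError(f'expected int, got {type(x).__name__}')
    return x

wants_int(3)        # ok
# wants_int(True)   # would raise TypeError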
@staticmethod
def check_typename(arg_name, arg_type, valid_types):
"""Does it contain the _name_ attribute."""
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
if isinstance(arg_type, type(mstype.tensor)):
arg_type = arg_type.element_type()
if arg_type in valid_types:
return arg_type
type_names = [get_typename(t) for t in valid_types]
if len(valid_types) == 1:
raise ValueError(f'The type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'The type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
@staticmethod
def check_string(arg_name, arg_value, valid_values):
"""String type judgment."""
if isinstance(arg_value, str) and arg_value in valid_values:
return arg_value
if len(valid_values) == 1:
raise ValueError(f'The `{arg_name}` should be str and must be {valid_values[0]},'
f' but got {arg_value}.')
raise ValueError(f'The `{arg_name}` should be str and must be one of {valid_values},'
f' but got {arg_value}.')
@staticmethod
def check_type_same(args, valid_values):
"""Determine whether the types are the same."""
name = list(args.keys())[0]
value = list(args.values())[0]
if isinstance(value, type(mstype.tensor)):
value = value.element_type()
for arg_name, arg_value in args.items():
if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()
if arg_value not in valid_values:
raise TypeError(f'The `{arg_name}` should be in {valid_values},'
f' but `{arg_name}` is {arg_value}.')
if arg_value != value:
raise TypeError(f'`{arg_name}` should be the same as `{name}`,'
f' but `{arg_name}` is {arg_value}, `{name}` is {value}.')
@staticmethod
def check_two_types_same(arg1_name, arg1_type, arg2_name, arg2_type):
"""Determine whether the types of two variables are the same."""
if arg1_type != arg2_type:
raise TypeError(f'The type of `{arg1_name}` and `{arg2_name}` should be the same.')
@staticmethod
def check_value_on_integer(arg_name, arg_value, value, rel):
"""Judging integer type."""
rel_fn = Rel.get_fns(rel)
type_match = isinstance(arg_value, int)
if type_match and (not rel_fn(arg_value, value)):
rel_str = Rel.get_strs(rel).format(value)
raise ValueError(f'The `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
return arg_value
@staticmethod
def check_param_equal(param1_name, param1_value, param2_name, param2_value):
"""Judging the equality of parameters."""
if param1_value != param2_value:
raise ValueError(f"`{param1_name}` must equal `{param2_name}`,"
f" but got `{param1_name}` = {param1_value},"
f" `{param2_name}` = {param2_value}.")
@staticmethod
def check_const_input(arg_name, arg_value):
"""Check valid value."""
if arg_value is None:
raise ValueError(f'The `{arg_name}` must be a const input, but got {arg_value}.')
@staticmethod
def check_float_positive(arg_name, arg_value):
"""Float type judgment."""
if isinstance(arg_value, float):
if arg_value > 0:
return arg_value
raise ValueError(f"The `{arg_name}` must be positive, but got {arg_value}.")
raise TypeError(f"`{arg_name}` must be float!")
@staticmethod
def check_pad_value_by_mode(op_name, pad_mode, padding):
"""Validate value of padding according to pad_mode"""
if pad_mode != 'pad' and padding != 0:
raise ValueError(f"For op '{op_name}', padding must be zero when pad_mode is '{pad_mode}'.")
return padding
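The rule being enforced: explicit nonzero padding only makes sense when pad_mode is 'pad'; 'same' and 'valid' derive padding themselves. Usage, for illustration:

# padding must stay 0 unless pad_mode == 'pad'
ParamValidator.check_pad_value_by_mode('Conv2D', pad_mode='pad', padding=2)    # ok, returns 2
ParamValidator.check_pad_value_by_mode('Conv2D', pad_mode='same', padding=0)   # ok, returns 0
# ParamValidator.check_pad_value_by_mode('Conv2D', pad_mode='same', padding=2) -> ValueError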
@staticmethod
def check_empty_shape_input(arg_name, arg_value):
"""Check zeros value."""
if 0 in arg_value:
raise ValueError(f"Input `{arg_name}` cannot be empty.")
@staticmethod
def check_scalar_shape_input(arg_name, arg_value):
"""Check scalar shape input."""
if arg_value != []:
raise ValueError(f"Input `{arg_name}` shape should be (). got {arg_value}")
def check_int(input_param):
"""Int type judgment."""
if isinstance(input_param, int) and not isinstance(input_param, bool):
@@ -653,30 +435,6 @@ def check_output_data(data):
raise RuntimeError('Executor return data ' + str(data) + ', please check your net or input data.')
def check_axis_type_int(axis):
"""Check axis type."""
if not isinstance(axis, int):
raise TypeError('Wrong type for axis, should be int.')
def check_axis_range(axis, rank):
"""Check axis range."""
if not -rank <= axis < rank:
raise ValueError('The axis should be in range [{}, {}), but got {}.'.format(-rank, rank, axis))
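check_axis_range follows the usual Python convention that a negative axis counts from the end, so for rank r the valid interval is [-r, r). For example:

check_axis_range(-1, rank=3)   # ok: -1 means the last axis
check_axis_range(2, rank=3)    # ok: highest valid positive axis
# check_axis_range(3, rank=3)  # ValueError: axis must be in [-3, 3)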
def check_attr_int(attr_name, attr):
"""Check int type."""
if not isinstance(attr, int):
raise TypeError("The attr {} should be int, but got {}.".format(attr_name, type(attr)))
def check_t_in_range(t):
"""Check input range."""
if t not in (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64):
raise ValueError("The param T should be (float16, float32, float64, int32, int64).")
once = _expand_tuple(1)
twice = _expand_tuple(2)
triple = _expand_tuple(3)
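_expand_tuple itself is not shown in this diff; the once/twice/triple aliases suggest it returns a callable that broadcasts an int to an n-tuple. A hypothetical sketch consistent with that usage:

def _expand_tuple(n_dim):
    """Sketch: return a function mapping an int to an n-tuple, or validating an n-tuple."""
    def expand(value):
        if isinstance(value, int):
            return (value,) * n_dim
        if isinstance(value, tuple) and len(value) == n_dim:
            return value
        raise ValueError(f'expected an int or a tuple of {n_dim} ints, but got {value}')
    return expand

twice = _expand_tuple(2)
assert twice(3) == (3, 3)
assert twice((1, 2)) == (1, 2)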
@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
UpdateAdjoint(node_adjoint);
anfnode_to_adjoin_[morph] = node_adjoint;
if (cnode_morph->stop_gradient()) {
MS_LOG(WARNING) << "MapMorphism node " << morph->ToString() << " is stopped.";
MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
return node_adjoint;
}
@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator
from ... import context
from ..cell import Cell
from ..._checkparam import Rel
from ..._checkparam import ParamValidator
class _PoolNd(Cell):
@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd):
stride=1,
pad_mode="valid"):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
ParamValidator.check_type('kernel_size', kernel_size, [int,])
ParamValidator.check_type('stride', stride, [int,])
self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
ParamValidator.check_integer("stride", stride, 1, Rel.GE)
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
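The hunk above is the call-site migration this commit performs: ParamValidator calls, which know nothing about the caller, become Validator calls that receive self.cls_name, so a bad argument is now reported against the layer's name. Assuming the new message format, misuse would surface like this:

import mindspore.nn as nn

try:
    nn.AvgPool1d(kernel_size=0)   # violates kernel_size >= 1
except ValueError as e:
    print(e)   # expected to name the layer, e.g. "For 'AvgPool1d' ..."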
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import functools
import numpy as np
from mindspore import ops
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.composite as C
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine
from mindspore.common.api import _executor
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward\
import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
import pipeline_for_compile_grad_ge_graph_for_case_by_case_config
class ExpandDimsNet(nn.Cell):
def __init__(self, axis):
super(ExpandDimsNet, self).__init__()
self.axis = axis
self.op = P.ExpandDims()
def construct(self, x):
return self.op(x, self.axis)
class IsInstanceNet(nn.Cell):
def __init__(self, inst):
super(IsInstanceNet, self).__init__()
self.inst = inst
self.op = P.IsInstance()
def construct(self, t):
return self.op(self.inst, t)
class ReshapeNet(nn.Cell):
def __init__(self, shape):
super(ReshapeNet, self).__init__()
self.shape = shape
self.op = P.Reshape()
def construct(self, x):
return self.op(x, self.shape)
raise_set = [
# input is a scalar, not a Tensor
('ExpandDims0', {
'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [5.0, 1],
'skip': ['backward']}),
# axis is passed as an input parameter
('ExpandDims1', {
'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 1],
'skip': ['backward']}),
# axis as an attribute, but less than the lower limit
('ExpandDims2', {
'block': (ExpandDimsNet(-4), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# axis as an attribute, but greater than the upper limit
('ExpandDims3', {
'block': (ExpandDimsNet(3), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input is a scalar, not a Tensor
('DType0', {
'block': (P.DType(), {'exception': TypeError, 'error_keywords': ['DType']}),
'desc_inputs': [5.0],
'skip': ['backward']}),
# input x is a scalar, not a Tensor
('SameTypeShape0', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input y is a scalar, not a Tensor
('SameTypeShape1', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
'skip': ['backward']}),
# types of x and y do not match
('SameTypeShape2', {
'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
'skip': ['backward']}),
# shapes of x and y do not match
('SameTypeShape3', {
'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
'skip': ['backward']}),
# sub_type is None
('IsSubClass0', {
'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'desc_inputs': [None, mstype.number],
'skip': ['backward']}),
# type_ is None
('IsSubClass1', {
'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
'desc_inputs': [mstype.number, None],
'skip': ['backward']}),
# inst is a variable input, not a constant
('IsInstance0', {
'block': (P.IsInstance(), {'exception': ValueError, 'error_keywords': ['IsInstance']}),
'desc_inputs': [5.0, mstype.number],
'skip': ['backward']}),
# t is not mstype.Type
('IsInstance1', {
'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
'desc_inputs': [None],
'skip': ['backward']}),
# input x is scalar, not Tensor
('Reshape0', {
'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
'desc_inputs': [5.0, (1, 2)],
'skip': ['backward']}),
# shape is passed as a variable input
('Reshape1', {
'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), (2, 3, 2)],
'skip': ['backward']}),
# an element of shape is not an int
('Reshape3', {
'block': (ReshapeNet((2, 3.0, 2)), {'exception': TypeError, 'error_keywords': ['Reshape']}),
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
]
@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
return raise_set
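Each raise_set entry is data rather than a test body: the pipeline builds the block, feeds it desc_inputs, and checks both the exception type and that every error_keywords entry appears in the message. A simplified sketch of such a pipeline loop (the real mindspore_test_framework is more involved):

import pytest

def run_exception_cases(cases):
    """Sketch: replay (name, config) cases against their declared exceptions."""
    for name, cfg in cases:
        block, expect = cfg['block']
        with pytest.raises(expect['exception']) as ex:
            block(*cfg['desc_inputs'])
        for kw in expect['error_keywords']:
            assert kw in str(ex.value), f'{name}: {kw!r} missing from error'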
@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
net = NetWork()
with pytest.raises(ValueError) as ex:
net(input_tensor)
assert "The `begin[0]` should be an int and must greater or equal to -6, but got -7" in str(ex.value)
assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(ex.value)
def test_tensor_slice_reduce_out_of_bounds_positive():
@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
net = NetWork()
with pytest.raises(ValueError) as ex:
net(input_tensor)
assert "The `begin[0]` should be an int and must less than 6, but got 6" in str(ex.value)
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
@@ -16,7 +16,7 @@
import numpy as np
from mindspore._checkparam import Rel
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
def avg_pooling(x, pool_h, pool_w, stride):
@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
Returns:
numpy.ndarray, an output array after applying average pooling on input array.
"""
validator.check_integer("stride", stride, 0, Rel.GT)
validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1
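The shape arithmetic here is the standard no-padding pooling formula out = (in - window)//stride + 1; for example, a 4x4 input with a 2x2 window and stride 2 gives a 2x2 output. A self-contained numpy reference of the average pooling being checked, for illustration:

import numpy as np

def avg_pooling_ref(x, pool_h, pool_w, stride):
    """Reference average pooling over NCHW input, no padding."""
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1
    out = np.zeros((num, channel, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, :, i*stride:i*stride+pool_h, j*stride:j*stride+pool_w]
            out[:, :, i, j] = window.mean(axis=(2, 3))
    return out

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
assert avg_pooling_ref(x, 2, 2, 2).shape == (1, 1, 2, 2)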
@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation=1, groups=1, padding_mode='zeros'):
"""Convolution 2D."""
# pylint: disable=unused-argument
validator.check_type('stride', stride, (int, tuple))
validator.check_value_type('stride', stride, (int, tuple), None)
if isinstance(stride, int):
stride = (stride, stride)
elif len(stride) == 4:
@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
f"a tuple of two positive int numbers, but got {stride}")
stride_h = stride[0]
stride_w = stride[1]
validator.check_type('dilation', dilation, (int, tuple))
validator.check_value_type('dilation', dilation, (int, tuple), None)
if isinstance(dilation, int):
dilation = (dilation, dilation)
elif len(dilation) == 4:
@@ -384,7 +384,7 @@ def matmul(x, w, b=None):
def max_pooling(x, pool_h, pool_w, stride):
"""Max pooling."""
validator.check_integer("stride", stride, 0, Rel.GT)
validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1
@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
def max_pool_with_argmax(x, pool_h, pool_w, stride):
"""Max pooling with argmax."""
validator.check_integer("stride", stride, 0, Rel.GT)
validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1
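max_pool_with_argmax additionally records, per output element, the flat index of the maximum inside its window; max_pool_grad_with_argmax above uses those indices to route gradients back. A hypothetical numpy sketch of the forward pass:

import numpy as np

def max_pool_with_argmax_ref(x, pool_h, pool_w, stride):
    """Sketch: max pooling that also returns the argmax within each window."""
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1
    out = np.zeros((num, channel, out_h, out_w), dtype=x.dtype)
    arg = np.zeros((num, channel, out_h, out_w), dtype=np.int64)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, :, i*stride:i*stride+pool_h, j*stride:j*stride+pool_w]
            flat = window.reshape(num, channel, -1)
            out[:, :, i, j] = flat.max(axis=2)
            arg[:, :, i, j] = flat.argmax(axis=2)   # index within the window
    return out, arg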