Commit 19e1494a authored by mindspore-ci-bot, committed by Gitee

!375 Add prim name to error message for _grad_ops.py

Merge pull request !375 from fary86/add_prim_name_to_error_message_for_grad_ops
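The change threads the primitive name through every validator call so that raised errors identify the failing operator. The snippet below is a minimal, self-contained sketch of that idea; it is not MindSpore's actual Validator implementation, and the helper name and message format are assumptions for illustration only.

# Sketch: a type check that prefixes its error with the primitive name.
def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    """Raise TypeError naming the operator when the type check fails."""
    if not isinstance(arg_value, tuple(valid_types)):
        prefix = f"For '{prim_name}', " if prim_name else ""
        type_names = [t.__name__ for t in valid_types]
        raise TypeError(f"{prefix}the '{arg_name}' should be one of {type_names}, "
                        f"but got {type(arg_value).__name__}.")
    return arg_value

# The operator name now appears in the message:
check_value_type('axis', 1, [int], prim_name='LogSoftmaxGrad')      # passes
# check_value_type('axis', 1.5, [int], prim_name='LogSoftmaxGrad')  # TypeError: For 'LogSoftmaxGrad', ...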
@@ -206,8 +206,8 @@ class Validator:
def _check_tensor_type(arg):
arg_key, arg_val = arg
elem_type = arg_val
type_names = []
if not elem_type in valid_values:
type_names = []
for t in valid_values:
type_names.append(str(t))
types_info = '[' + ", ".join(type_names) + ']'
@@ -15,7 +15,7 @@
"""utils for operator"""
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
@@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
return broadcast_shape
def _get_concat_offset(x_shp, x_type, axis):
def _get_concat_offset(x_shp, x_type, axis, prim_name):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
validator.check_value_type("shape", x_shp, [tuple], prim_name)
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis
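As a quick illustration of what this helper computes, the standalone re-implementation below reproduces just the offset loop shown above (hypothetical function name, no MindSpore calls): each input's starting position along the concat axis is the running sum of the preceding sizes on that axis.

# Standalone illustration of the concat-offset computation.
def concat_offsets(shapes, axis):
    all_shp = shapes[0][axis]
    offsets = [0]
    for shp in shapes[1:]:
        offsets.append(all_shp)
        all_shp += shp[axis]
    return offsets, all_shp

# Concatenating (2, 3) and (2, 5) along axis 1:
print(concat_offsets([(2, 3), (2, 5)], axis=1))   # ([0, 3], 8)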
@@ -18,8 +18,7 @@
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel, check_int_positive, check_bool
from ..._checkparam import Validator as validator, Rel
from .._utils import _get_concat_offset
from ...common import dtype as mstype
@@ -51,12 +50,12 @@ class ACosGrad(PrimitiveWithInfer):
"""init ACosGrad"""
def infer_shape(self, x, dout):
validator.check_param_equal("x", x, "dout", dout)
validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
return x
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_type_same(args, mstype.number_type)
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return x
@@ -65,8 +64,8 @@ class BatchNormGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5):
self.is_training = validator.check_type('is_training', is_training, (bool,))
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT)
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
@@ -93,19 +92,19 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer):
"""Computes gradients for `BinaryCrossEntropy` operation."""
@prim_attr_register
def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'])
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape)
validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
if weight_shape:
validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape)
validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
validator.check_type_same(args, (mstype.float16, mstype.float32))
validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
if weight_type:
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError)
return x_type
@@ -120,7 +119,7 @@ class ConcatOffset(PrimitiveWithInfer):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
@@ -184,11 +183,11 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
def __infer__(self, doutput, x, w_size):
w_size_v = w_size['value']
validator.check_type('w_size', w_size_v, [tuple])
validator.check_value_type('w_size', w_size_v, [tuple], self.name)
for i, dim_len in enumerate(w_size_v):
validator.check_type("w_size[%d]" % i, dim_len, [int])
validator.check_typename('x_dtype', x['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32])
validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'x_dtype', x['dtype'])
validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name)
args = {"x": x['dtype'], "doutput": doutput['dtype']}
validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name)
out = {
'value': None,
'shape': w_size_v,
@@ -250,8 +249,8 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
def __infer__(self, x, w_size, dout):
w_size_v = w_size['value']
args = {'x_dtype': x['dtype'], 'dout_type': dout['dtype']}
validator.check_type_same(args, mstype.number_type)
args = {'x': x['dtype'], 'dout': dout['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
out = {
'value': None,
'shape': w_size_v,
@@ -310,8 +309,8 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
raise NotImplementedError
def __infer__(self, x_size, w, dout):
args = {'w_dtype': w['dtype'], 'dout_type': dout['dtype']}
validator.check_type_same(args, mstype.number_type)
args = {'w': w['dtype'], 'dout': dout['dtype']}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
x_size_v = x_size['value']
out = {
'value': None,
@@ -360,9 +359,9 @@ class GeluGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
validator.check_typename("y_dtype", y_dtype, (mstype.float16, mstype.float32))
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name)
return x_dtype
@@ -373,56 +372,36 @@ class _PoolGrad(PrimitiveWithInfer):
def __init__(self, ksize, strides, padding="VALID"):
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
validator.check_type('ksize', ksize, [int, tuple])
validator.check_type('strides', strides, [int, tuple])
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.add_prim_attr("padding", self.padding)
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
if not self.is_maxpoolgradwithargmax:
self.add_prim_attr('data_format', "NCHW")
if isinstance(ksize, int):
validator.check_integer("ksize", ksize, 1, Rel.GE)
if self.is_maxpoolgradwithargmax:
self.ksize = (1, ksize, ksize, 1)
def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number "
f"or a tuple of two or four positive int numbers, but got {arg_val}")
if isinstance(arg_val, int):
ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val)
elif len(arg_val) == 2:
ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1])
elif len(arg_val) == 4:
ret = arg_val
else:
self.ksize = (1, 1, ksize, ksize)
else:
ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number"
f"or a tuple of two or four positive int numbers, but got {ksize}")
if len(ksize) != 2 and len(ksize) != 4:
raise ksize_error
for ksize_val in ksize:
if not isinstance(ksize_val, int) or (ksize_val <= 0):
raise ksize_error
if len(ksize) == 2 and self.is_maxpoolgradwithargmax:
self.ksize = (1, ksize[0], ksize[1], 1)
elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax:
self.ksize = (1, 1, ksize[0], ksize[1])
else:
self.ksize = ksize
raise error_msg
# whether all elements of tuple are positive integers
for item in ret:
if not isinstance(item, int) or item <= 0:
raise error_msg
return ret
self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
self.add_prim_attr("ksize", self.ksize)
if isinstance(strides, int):
validator.check_integer("strides", strides, 1, Rel.GE)
if self.is_maxpoolgradwithargmax:
self.strides = (1, strides, strides, 1)
else:
self.strides = (1, 1, strides, strides)
else:
strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number"
f"or a tuple of two or four positive int numbers, but got {strides}")
if len(strides) != 2 and len(strides) != 4:
raise strides_error
for strides_val in strides:
if not isinstance(strides_val, int) or (strides_val <= 0):
raise strides_error
if len(strides) == 2 and self.is_maxpoolgradwithargmax:
self.strides = (1, strides[0], strides[1], 1)
elif len(strides) == 2 and not self.is_maxpoolgradwithargmax:
self.strides = (1, 1, strides[0], strides[1])
else:
self.strides = strides
self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
self.add_prim_attr("strides", self.strides)
@@ -529,17 +508,17 @@ class L2NormalizeGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=0, epsilon=1e-4):
validator.check_type('axis', axis, [int])
validator.check_type('epsilon', epsilon, [int, float])
validator.check_value_type('axis', axis, [int], self.name)
validator.check_value_type('epsilon', epsilon, [int, float], self.name)
def infer_shape(self, input_x, out, dout):
validator.check_param_equal('input_x', input_x, 'out', out)
validator.check_param_equal('input_x', input_x, 'dout', dout)
validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name)
validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name)
return input_x
def infer_dtype(self, input_x, out, dout):
args = {'input_x': input_x, 'out': out, 'dout': dout}
validator.check_type_same(args, mstype.number_type)
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return input_x
@@ -560,8 +539,8 @@ class LayerNormGrad(Primitive):
@prim_attr_register
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
"""init"""
self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name)
self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name)
def __call__(self, x, dy, variance, mean, gamma):
raise NotImplementedError
@@ -573,15 +552,15 @@ class LogSoftmaxGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=-1):
"""init LogSoftmaxGrad"""
validator.check_type("axis", axis, [int])
validator.check_value_type("axis", axis, [int], self.name)
def infer_shape(self, dout, logits):
rank = len(logits)
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH)
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name)
return logits
def infer_dtype(self, dout, logits):
validator.check_subclass("logits", logits, mstype.tensor)
validator.check_subclass("logits", logits, mstype.tensor, self.name)
return logits
@@ -590,13 +569,13 @@ class LSTMGradData(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
self.input_size = check_int_positive(input_size)
self.hidden_size = check_int_positive(hidden_size)
self.num_layers = check_int_positive(num_layers)
self.has_bias = check_bool(has_bias)
self.bidirectional = check_bool(bidirectional)
self.dropout = validator.check_type("dropout", dropout, [float])
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH)
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
if bidirectional:
self.num_directions = 2
@@ -606,19 +585,19 @@ class LSTMGradData(PrimitiveWithInfer):
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
hx_shape, cx_shape, reserve_shape, state_shape):
# dhy and dcy should be same shape
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ)
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name)
# dy: (seq_len, batch_size, hidden_size * num_directions)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name)
# (seq_len, batch_size, input_size)
dx_shape = (y_shape[0], y_shape[1], self.input_size)
@@ -629,11 +608,8 @@ class LSTMGradData(PrimitiveWithInfer):
def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
hx_dtype, cx_dtype, reserve_dtype, state_dtype):
validator.check_typename("dy_dtype", dy_dtype, (mstype.float32, mstype.float16))
validator.check_typename("dhy_dtype", dhy_dtype, (mstype.float32, mstype.float16))
validator.check_typename("dcy_dtype", dcy_dtype, (mstype.float32, mstype.float16))
validator.check_typename("datatype", dy_dtype, (dhy_dtype.element_type(),))
validator.check_typename("datatype", dy_dtype, (dcy_dtype.element_type(),))
args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype}
validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
return (dy_dtype, dy_dtype, dy_dtype)
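Several hunks in this diff replace a series of per-argument check_typename calls with a single dict-based check. The helper below is a rough sketch of that pattern, not MindSpore's check_tensor_type_same: every named dtype must come from the valid set, and all of them must agree with each other.

# Illustrative dict-based "same tensor type" check (plain strings stand in for mstype).
def check_same_type(args, valid_types, prim_name):
    names = list(args)
    first = args[names[0]]
    for name, dtype in args.items():
        if dtype not in valid_types:
            raise TypeError(f"For '{prim_name}', '{name}' should be one of {valid_types}, "
                            f"but got {dtype}.")
        if dtype != first:
            raise TypeError(f"For '{prim_name}', '{name}' ({dtype}) should have the same "
                            f"type as '{names[0]}' ({first}).")
    return first

check_same_type({"dy": "float16", "dhy": "float16", "dcy": "float16"},
                ("float16", "float32"), "LSTMGradData")   # passes, returns "float16"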
@@ -642,13 +618,13 @@ class LSTMGradWeight(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
self.input_size = check_int_positive(input_size)
self.hidden_size = check_int_positive(hidden_size)
self.num_layers = check_int_positive(num_layers)
self.has_bias = check_bool(has_bias)
self.bidirectional = check_bool(bidirectional)
self.dropout = validator.check_type("dropout", dropout, [float])
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH)
self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name)
self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name)
self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name)
self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
self.dropout = validator.check_value_type("dropout", dropout, [float], self.name)
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name)
if bidirectional:
self.num_directions = 2
@@ -693,9 +669,10 @@ class PReLUGrad(PrimitiveWithInfer):
return y_backprop_shape, w_shape
def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32))
validator.check_typename("A_dtype", A_dtype, (mstype.float16, mstype.float32))
validator.check_typename("w_dtype", w_dtype, (mstype.float16, mstype.float32))
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name)
validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name)
return y_backprop_dtype, w_dtype
@@ -725,8 +702,8 @@ class ReLU6Grad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad_dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
return x_dtype
@@ -744,10 +721,8 @@ class ReluGradV2(PrimitiveWithInfer):
return gradients_shape
def infer_dtype(self, gradients_dtype, mask_dtype):
args_type = {'gradients': gradients_dtype, 'mask': mask_dtype}
validator.check_args_tensor(args_type)
validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type)
validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,))
validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name)
validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name)
return gradients_dtype
@@ -762,10 +737,8 @@ class EluGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
args_type = {'y_grad': y_grad_dtype, 'x': x_dtype}
validator.check_args_tensor(args_type)
args_dtype = {'y_grad_dtype': y_grad_dtype, 'x_dtype': x_dtype}
validator.check_type_same(args_dtype, mstype.float_type)
args = {'y_grad': y_grad_dtype, 'x': x_dtype}
validator.check_tensor_type_same(args, mstype.float_type, self.name)
return x_dtype
@@ -821,11 +794,11 @@ class ROIAlignGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
"""init ROIAlignGrad"""
validator.check_type("pooled_height", pooled_height, [int])
validator.check_type("pooled_width", pooled_width, [int])
validator.check_type("spatial_scale", spatial_scale, [float])
validator.check_type("sample_num", sample_num, [int])
validator.check_type("xdiff_shape", xdiff_shape, [tuple])
validator.check_value_type("pooled_height", pooled_height, [int], self.name)
validator.check_value_type("pooled_width", pooled_width, [int], self.name)
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name)
validator.check_value_type("sample_num", sample_num, [int], self.name)
validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name)
self.xdiff_shape = xdiff_shape
self.pooled_height = pooled_height
self.pooled_width = pooled_width
@@ -850,10 +823,8 @@ class SigmoidGrad(PrimitiveWithInfer):
return out
def infer_dtype(self, out, dout):
validator.check_typename("dout dtype", dout, (mstype.float16, mstype.float32))
validator.check_typename("out dtype", out, (mstype.float16, mstype.float32))
args = {"out type": out, "dout type": dout}
validator.check_type_same(args, mstype.number_type)
args = {'out': out, 'dout': dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return out
@@ -868,8 +839,8 @@ class HSigmoidGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32))
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
return x_dtype
@@ -884,8 +855,8 @@ class HSwishGrad(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32))
validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name)
validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
return x_dtype
@@ -898,13 +869,13 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
def infer_shape(self, x_shape, y_shape, dout_shape):
validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape)
validator.check_param_equal("x_shape", x_shape, "dout_shape", dout_shape)
validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name)
validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name)
return x_shape
def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
validator.check_type_same(args, mstype.number_type)
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return dout_dtype
@@ -920,8 +891,8 @@ class SliceGrad(PrimitiveWithInfer):
dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
dy_shape_len = len(dy_shape)
for i in range(dy_shape_len):
validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE)
validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ)
validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name)
validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name)
return {'shape': x_shape,
'dtype': x['dtype'],
'value': None}
@@ -935,13 +906,13 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
pass
def infer_shape(self, prediction, target, dloss):
validator.check_param_equal('prediction', prediction, 'target', target)
validator.check_param_equal('prediction', prediction, 'dloss', dloss)
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
return prediction
def infer_dtype(self, prediction, target, dloss):
args = {"prediction": prediction, "target": target, 'dloss': dloss}
validator.check_type_same(args, mstype.number_type)
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return dloss
@@ -968,11 +939,11 @@ class StridedSliceGrad(PrimitiveWithInfer):
new_axis_mask=0,
shrink_axis_mask=0):
"""init StrideSliceGrad"""
validator.check_type('begin_mask', begin_mask, [int])
validator.check_type('end_mask', end_mask, [int])
validator.check_type('ellipsis_mask', ellipsis_mask, [int])
validator.check_type('new_axis_mask', new_axis_mask, [int])
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
validator.check_value_type('end_mask', end_mask, [int], self.name)
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides):
@@ -992,10 +963,8 @@ class TanhGrad(PrimitiveWithInfer):
return out
def infer_dtype(self, out, dout):
validator.check_subclass("out", out, mstype.tensor)
validator.check_subclass("dout", dout, mstype.tensor)
args = {"out type": out, "dout type": dout}
validator.check_type_same(args, mstype.number_type)
args = {"out": out, "dout": dout}
validator.check_tensor_type_same(args, mstype.number_type, self.name)
return out
@@ -1005,13 +974,13 @@ class MirrorPadGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, mode="REFLECT"):
"""init MirrorPad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'])
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
self.mode = mode
def __infer__(self, dout, paddings, x):
validator.check_subclass("dout", dout['dtype'], mstype.tensor)
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor)
validator.check_subclass("input_x", x['dtype'], mstype.tensor)
validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
return {'shape': x['shape'],
'dtype': dout['dtype'],
'value': None}
@@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis)
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type())
self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy()