提交 19e1494a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!375 Add prim name to error message for _grad_ops.py

Merge pull request !375 from fary86/add_prim_name_to_error_message_for_grad_ops
......@@ -206,8 +206,8 @@ class Validator:
def _check_tensor_type(arg):
arg_key, arg_val = arg
elem_type = arg_val
type_names = []
if not elem_type in valid_values:
type_names = []
for t in valid_values:
type_names.append(str(t))
types_info = '[' + ", ".join(type_names) + ']'
......
......@@ -15,7 +15,7 @@
"""utils for operator"""
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
......@@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name):
return broadcast_shape
def _get_concat_offset(x_shp, x_type, axis):
def _get_concat_offset(x_shp, x_type, axis, prim_name):
"""for concat and concatoffset check args and compute offset"""
validator.check_type("shape", x_shp, [tuple])
validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT)
validator.check_subclass("shape0", x_type[0], mstype.tensor)
validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT)
validator.check_value_type("shape", x_shp, [tuple], prim_name)
validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name)
validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name)
validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name)
rank_base = len(x_shp[0])
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH)
validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name)
if axis < 0:
axis = axis + rank_base
all_shp = x_shp[0][axis]
offset = [0,]
for i in range(1, len(x_shp)):
v = x_shp[i]
validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0]))
validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0])
validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name)
validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name)
for j in range(rank_base):
if j != axis and v[j] != x_shp[0][j]:
raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i)
raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element")
offset.append(all_shp)
all_shp += v[axis]
return offset, all_shp, axis
......@@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis)
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name)
self.add_prim_attr('T', x_type[0].element_type())
self.add_prim_attr('inputNums', len(x_shp))
ret_shp = x_shp[0].copy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册