diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 0b44e83b2f3ef3b46e0046a4f6eb47bb6ddb9223..26f0c98223681ece138fdab8424c10807425b321 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_ from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_dtype as sig_dtype +from ..._c_expression import typing def _check_infer_attr_reduce(axis, keep_dims, prim_name): validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) @@ -196,8 +197,7 @@ class Cast(PrimitiveWithInfer): data = x.default_input if data.dtype == dtype: return (True, x) - return (False, None) - raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") + return (False, None) def __infer__(self, x, t): src_type = x['dtype'] @@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer): def check_elim(self, base_tensor, multiplier): if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): - raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) - def is_all_zeros(v_tuple): - return all(v == 1 for v in v_tuple) - if is_all_zeros(multiplier): + raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) + if all(v == 1 for v in multiplier): return (True, base_tensor) return (False, None) @@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer): validator.check_value_type("shape", multiples_v, [tuple], self.name) for i, multiple in enumerate(multiples_v): validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) - valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32] - validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name) + validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name) len_sub = len(multiples_v) - len(x_shp) multiples_w = None if len_sub == 0: