提交 62aed2ff 编写于 作者: B BowenK

fix param check for check_elim

上级 7f891f62
...@@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_ ...@@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_
from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import typing
def _check_infer_attr_reduce(axis, keep_dims, prim_name): def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name) validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
...@@ -197,7 +198,6 @@ class Cast(PrimitiveWithInfer): ...@@ -197,7 +198,6 @@ class Cast(PrimitiveWithInfer):
if data.dtype == dtype: if data.dtype == dtype:
return (True, x) return (True, x)
return (False, None) return (False, None)
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
def __infer__(self, x, t): def __infer__(self, x, t):
src_type = x['dtype'] src_type = x['dtype']
...@@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer): ...@@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer):
def check_elim(self, base_tensor, multiplier): def check_elim(self, base_tensor, multiplier):
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
def is_all_zeros(v_tuple): if all(v == 1 for v in multiplier):
return all(v == 1 for v in v_tuple)
if is_all_zeros(multiplier):
return (True, base_tensor) return (True, base_tensor)
return (False, None) return (False, None)
...@@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer): ...@@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer):
validator.check_value_type("shape", multiples_v, [tuple], self.name) validator.check_value_type("shape", multiples_v, [tuple], self.name)
for i, multiple in enumerate(multiples_v): for i, multiple in enumerate(multiples_v):
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32] validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
len_sub = len(multiples_v) - len(x_shp) len_sub = len(multiples_v) - len(x_shp)
multiples_w = None multiples_w = None
if len_sub == 0: if len_sub == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册