提交 62aed2ff 编写于 作者: B BowenK

fix param check for check_elim

上级 7f891f62
......@@ -35,6 +35,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from ..._c_expression import typing
def _check_infer_attr_reduce(axis, keep_dims, prim_name):
validator.check_value_type('keep_dims', keep_dims, [bool], prim_name)
......@@ -196,8 +197,7 @@ class Cast(PrimitiveWithInfer):
data = x.default_input
if data.dtype == dtype:
return (True, x)
return (False, None)
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
return (False, None)
def __infer__(self, x, t):
src_type = x['dtype']
......@@ -1233,10 +1233,8 @@ class Tile(PrimitiveWithInfer):
def check_elim(self, base_tensor, multiplier):
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
def is_all_zeros(v_tuple):
return all(v == 1 for v in v_tuple)
if is_all_zeros(multiplier):
raise TypeError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
if all(v == 1 for v in multiplier):
return (True, base_tensor)
return (False, None)
......@@ -1246,8 +1244,7 @@ class Tile(PrimitiveWithInfer):
validator.check_value_type("shape", multiples_v, [tuple], self.name)
for i, multiple in enumerate(multiples_v):
validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name)
valid_types = [mstype.int16, mstype.int32, mstype.bool_, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, self.name)
validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name)
len_sub = len(multiples_v) - len(x_shp)
multiples_w = None
if len_sub == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册