提交 35a57e07 编写于 作者: B BowenK

fix cast check elim

上级 b0963833
......@@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
def check_elim(self, x, dtype):
if isinstance(x, Tensor):
if x.dtype == dtype:
if isinstance(x, (Tensor, numbers.Number)):
if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
return (False, None)
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
def __infer__(self, x, t):
src_type = x['dtype']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册