diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 89880315ace22492f096e75d9b4d76f05e196fd5..9ced3684da276e05f102cc2ec3c7d1624867ff06 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) def check_elim(self, x, dtype): - if isinstance(x, Tensor): - if x.dtype == dtype: + if isinstance(x, (Tensor, numbers.Number)): + if isinstance(x, Tensor) and x.dtype == dtype: return (True, x) + if isinstance(x, numbers.Number): + return (True, Tensor(x, dtype=dtype)) return (False, None) - raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) + raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") def __infer__(self, x, t): src_type = x['dtype']