From 35a57e076d5dca2a49f69bf467608b08ff65ba11 Mon Sep 17 00:00:00 2001 From: BowenK Date: Mon, 15 Jun 2020 10:07:19 +0800 Subject: [PATCH] fix cast check elim --- mindspore/ops/operations/array_ops.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 89880315a..9ced3684d 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) def check_elim(self, x, dtype): - if isinstance(x, Tensor): - if x.dtype == dtype: + if isinstance(x, (Tensor, numbers.Number)): + if isinstance(x, Tensor) and x.dtype == dtype: return (True, x) + if isinstance(x, numbers.Number): + return (True, Tensor(x, dtype=dtype)) return (False, None) - raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) + raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") def __infer__(self, x, t): src_type = x['dtype'] -- GitLab