diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 47a814fc2330217d0bad783f8528fa644097d896..b940d1b8795d0265d35d7d39a50647c34322b824 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -178,12 +178,18 @@ def get_py_obj_dtype(obj): Type of MindSpore type. """ # Tensor - if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): + if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): return tensor_type(obj.dtype) + # Primitive or Cell if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): return function - if isinstance(obj, (typing.Type, type)): + # mindspore type + if isinstance(obj, typing.Type): + return type_type + # python type + if isinstance(obj, type): return pytype_to_dtype(obj) + # others return pytype_to_dtype(type(obj)) diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py index 1e8849cc970b1f8dab5ac67a178fc59a3182b362..b92869079332b1323872ca5d4e4519c141693ddb 100644 --- a/tests/ut/python/pynative_mode/ops/test_grad.py +++ b/tests/ut/python/pynative_mode/ops/test_grad.py @@ -19,7 +19,6 @@ import mindspore as ms import mindspore.ops.operations as P from mindspore import Tensor, context from mindspore.common.api import ms_function -from mindspore.common.dtype import get_py_obj_dtype from mindspore.ops import composite as C from mindspore.ops import functional as F from ...ut_filter import non_graph_engine @@ -90,7 +89,7 @@ def test_cast_grad(): def test_scalar_cast_grad(): """ test_scalar_cast_grad """ input_x = 255.5 - input_t = get_py_obj_dtype(ms.int8) + input_t = ms.int8 def fx_cast(x): output = F.scalar_cast(x, input_t) diff --git a/tests/ut/python/pynative_mode/test_staging.py b/tests/ut/python/pynative_mode/test_staging.py index 7bf2fd47a42108c6bead2b9eec844529b7b14e7d..c5e739673975e9d045c55b369e0d33edf2b38e1e 100644 --- a/tests/ut/python/pynative_mode/test_staging.py +++ b/tests/ut/python/pynative_mode/test_staging.py @@ -23,7 +23,6 @@ from mindspore import context from mindspore.common import MetaTensor from mindspore.common import dtype from mindspore.common.api import ms_function -from mindspore.common.dtype import get_py_obj_dtype from mindspore.ops import functional as F from mindspore.ops import operations as P from ..ut_filter import non_graph_engine @@ -185,7 +184,7 @@ def test_input_signature(): def test_scalar_cast(): """ test_scalar_cast """ input_x = 8.5 - input_t = get_py_obj_dtype(ms.int64) + input_t = ms.int64 @ms_function def fn_cast(x, t):