From c1469d61ceadc53eb5adba8b386d81101ccabc10 Mon Sep 17 00:00:00 2001 From: He Wei Date: Wed, 2 Sep 2020 15:28:32 +0800 Subject: [PATCH] Fix get_py_obj_dtype() for mindspore type Return mstype.type_type if input is mindspore type object, for example: ``` get_py_obj_dtype(mstype.float32) ---> mstype.type_type ``` --- mindspore/common/dtype.py | 10 ++++++++-- tests/ut/python/pynative_mode/ops/test_grad.py | 3 +-- tests/ut/python/pynative_mode/test_staging.py | 3 +-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 47a814fc2..b940d1b87 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -178,12 +178,18 @@ def get_py_obj_dtype(obj): Type of MindSpore type. """ # Tensor - if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): + if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): return tensor_type(obj.dtype) + # Primitive or Cell if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): return function - if isinstance(obj, (typing.Type, type)): + # mindspore type + if isinstance(obj, typing.Type): + return type_type + # python type + if isinstance(obj, type): return pytype_to_dtype(obj) + # others return pytype_to_dtype(type(obj)) diff --git a/tests/ut/python/pynative_mode/ops/test_grad.py b/tests/ut/python/pynative_mode/ops/test_grad.py index 1e8849cc9..b92869079 100644 --- a/tests/ut/python/pynative_mode/ops/test_grad.py +++ b/tests/ut/python/pynative_mode/ops/test_grad.py @@ -19,7 +19,6 @@ import mindspore as ms import mindspore.ops.operations as P from mindspore import Tensor, context from mindspore.common.api import ms_function -from mindspore.common.dtype import get_py_obj_dtype from mindspore.ops import composite as C from mindspore.ops import functional as F from ...ut_filter import non_graph_engine @@ -90,7 +89,7 @@ def test_cast_grad(): def test_scalar_cast_grad(): """ test_scalar_cast_grad """ input_x = 255.5 - input_t = get_py_obj_dtype(ms.int8) + input_t = ms.int8 def fx_cast(x): output = F.scalar_cast(x, input_t) diff --git a/tests/ut/python/pynative_mode/test_staging.py b/tests/ut/python/pynative_mode/test_staging.py index 7bf2fd47a..c5e739673 100644 --- a/tests/ut/python/pynative_mode/test_staging.py +++ b/tests/ut/python/pynative_mode/test_staging.py @@ -23,7 +23,6 @@ from mindspore import context from mindspore.common import MetaTensor from mindspore.common import dtype from mindspore.common.api import ms_function -from mindspore.common.dtype import get_py_obj_dtype from mindspore.ops import functional as F from mindspore.ops import operations as P from ..ut_filter import non_graph_engine @@ -185,7 +184,7 @@ def test_input_signature(): def test_scalar_cast(): """ test_scalar_cast """ input_x = 8.5 - input_t = get_py_obj_dtype(ms.int64) + input_t = ms.int64 @ms_function def fn_cast(x, t): -- GitLab