提交 77dd91a6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5669 Fix get_py_obj_dtype() for mindspore type

Merge pull request !5669 from hewei/fix_get_py_obj_dtype
...@@ -178,12 +178,18 @@ def get_py_obj_dtype(obj): ...@@ -178,12 +178,18 @@ def get_py_obj_dtype(obj):
Type of MindSpore type. Type of MindSpore type.
""" """
# Tensor # Tensor
if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type): if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
return tensor_type(obj.dtype) return tensor_type(obj.dtype)
# Primitive or Cell
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'): if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function return function
if isinstance(obj, (typing.Type, type)): # mindspore type
if isinstance(obj, typing.Type):
return type_type
# python type
if isinstance(obj, type):
return pytype_to_dtype(obj) return pytype_to_dtype(obj)
# others
return pytype_to_dtype(type(obj)) return pytype_to_dtype(type(obj))
......
...@@ -19,7 +19,6 @@ import mindspore as ms ...@@ -19,7 +19,6 @@ import mindspore as ms
import mindspore.ops.operations as P import mindspore.ops.operations as P
from mindspore import Tensor, context from mindspore import Tensor, context
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from ...ut_filter import non_graph_engine from ...ut_filter import non_graph_engine
...@@ -90,7 +89,7 @@ def test_cast_grad(): ...@@ -90,7 +89,7 @@ def test_cast_grad():
def test_scalar_cast_grad(): def test_scalar_cast_grad():
""" test_scalar_cast_grad """ """ test_scalar_cast_grad """
input_x = 255.5 input_x = 255.5
input_t = get_py_obj_dtype(ms.int8) input_t = ms.int8
def fx_cast(x): def fx_cast(x):
output = F.scalar_cast(x, input_t) output = F.scalar_cast(x, input_t)
......
...@@ -23,7 +23,6 @@ from mindspore import context ...@@ -23,7 +23,6 @@ from mindspore import context
from mindspore.common import MetaTensor from mindspore.common import MetaTensor
from mindspore.common import dtype from mindspore.common import dtype
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
...@@ -185,7 +184,7 @@ def test_input_signature(): ...@@ -185,7 +184,7 @@ def test_input_signature():
def test_scalar_cast(): def test_scalar_cast():
""" test_scalar_cast """ """ test_scalar_cast """
input_x = 8.5 input_x = 8.5
input_t = get_py_obj_dtype(ms.int64) input_t = ms.int64
@ms_function @ms_function
def fn_cast(x, t): def fn_cast(x, t):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册