提交 77dd91a6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5669 Fix get_py_obj_dtype() for mindspore type

Merge pull request !5669 from hewei/fix_get_py_obj_dtype
......@@ -178,12 +178,18 @@ def get_py_obj_dtype(obj):
Type of MindSpore type.
"""
# Tensor
if hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
return tensor_type(obj.dtype)
# Primitive or Cell
if hasattr(obj, '__primitive_flag__') or hasattr(obj, 'construct'):
return function
if isinstance(obj, (typing.Type, type)):
# mindspore type
if isinstance(obj, typing.Type):
return type_type
# python type
if isinstance(obj, type):
return pytype_to_dtype(obj)
# others
return pytype_to_dtype(type(obj))
......
......@@ -19,7 +19,6 @@ import mindspore as ms
import mindspore.ops.operations as P
from mindspore import Tensor, context
from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from ...ut_filter import non_graph_engine
......@@ -90,7 +89,7 @@ def test_cast_grad():
def test_scalar_cast_grad():
""" test_scalar_cast_grad """
input_x = 255.5
input_t = get_py_obj_dtype(ms.int8)
input_t = ms.int8
def fx_cast(x):
output = F.scalar_cast(x, input_t)
......
......@@ -23,7 +23,6 @@ from mindspore import context
from mindspore.common import MetaTensor
from mindspore.common import dtype
from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine
......@@ -185,7 +184,7 @@ def test_input_signature():
def test_scalar_cast():
""" test_scalar_cast """
input_x = 8.5
input_t = get_py_obj_dtype(ms.int64)
input_t = ms.int64
@ms_function
def fn_cast(x, t):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册