提交 4cc798c6 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2502 1. fix infer value bug 2. tensor init support numpy number

Merge pull request !2502 from geekun/master_fix_issue
...@@ -22,6 +22,10 @@ from . import dtype as mstype ...@@ -22,6 +22,10 @@ from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor'] __all__ = ['Tensor', 'MetaTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)
class Tensor(Tensor_): class Tensor(Tensor_):
...@@ -54,6 +58,10 @@ class Tensor(Tensor_): ...@@ -54,6 +58,10 @@ class Tensor(Tensor_):
""" """
def __init__(self, input_data, dtype=None): def __init__(self, input_data, dtype=None):
# If input data is numpy number, convert it to np array
if isinstance(input_data, np_types):
input_data = np.array(input_data)
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
check_type('tensor input_data', input_data, (Tensor_, float, int)) check_type('tensor input_data', input_data, (Tensor_, float, int))
if dtype is not None: if dtype is not None:
......
...@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer): ...@@ -888,7 +888,8 @@ class Neg(PrimitiveWithInfer):
def infer_value(self, input_x): def infer_value(self, input_x):
if input_x is not None: if input_x is not None:
input_x = input_x.asnumpy() input_x = input_x.asnumpy()
return Tensor(-input_x) out = np.array(-input_x, input_x.dtype)
return Tensor(out)
return None return None
...@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp): ...@@ -1667,7 +1668,8 @@ class Div(_MathBinaryOp):
if x is not None and y is not None: if x is not None and y is not None:
x = x.asnumpy() x = x.asnumpy()
y = y.asnumpy() y = y.asnumpy()
return Tensor(x / y) out = np.array(x / y, x.dtype)
return Tensor(out)
return None return None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册