提交 9785178b 编写于 作者: K kingfo

add tensor compare & len & constexpr operation

上级 5c4731b7
...@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) { ...@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
value_ret[0] = output["value"]; value_ret[0] = output["value"];
return value_ret; return value_ret;
} }
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
}
} }
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true); mindspore::parse::python_adapter::set_python_env_flag(true);
......
...@@ -71,19 +71,18 @@ class Tensor(Tensor_): ...@@ -71,19 +71,18 @@ class Tensor(Tensor_):
return str(self.__str__()) return str(self.__str__())
def __add__(self, other): def __add__(self, other):
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(self, other) out = tensor_operator_registry.get('__add__')(self, other)
return out return out
def __eq__(self, other): def __eq__(self, other):
if not isinstance(other, Tensor): if not isinstance(other, Tensor):
return False return False
return Tensor(np.array(self.asnumpy() == other.asnumpy())) return tensor_operator_registry.get('__eq__')(self, other)
def __ne__(self, other): def __ne__(self, other):
if not isinstance(other, Tensor): if not isinstance(other, Tensor):
return True return True
return Tensor(np.array(self.asnumpy() != other.asnumpy())) return tensor_operator_registry.get('__ne__')(self, other)
def __hash__(self): def __hash__(self):
return hash(id(self)) return hash(id(self))
...@@ -93,7 +92,8 @@ class Tensor(Tensor_): ...@@ -93,7 +92,8 @@ class Tensor(Tensor_):
return out return out
def __neg__(self): def __neg__(self):
return Tensor(-self.asnumpy()) out = tensor_operator_registry.get('__neg__')(self)
return out
def __iadd__(self, other): def __iadd__(self, other):
out = self.__add__(other) out = self.__add__(other)
...@@ -120,7 +120,7 @@ class Tensor(Tensor_): ...@@ -120,7 +120,7 @@ class Tensor(Tensor_):
return out return out
def __sub__(self, other): def __sub__(self, other):
out = self.__add__(-other) out = tensor_operator_registry.get('__sub__')(self, other)
return out return out
def __isub__(self, other): def __isub__(self, other):
...@@ -128,9 +128,31 @@ class Tensor(Tensor_): ...@@ -128,9 +128,31 @@ class Tensor(Tensor_):
return out return out
def __rsub__(self, other): def __rsub__(self, other):
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy())) out = tensor_operator_registry.get('__sub__')(other, self)
return out
def __lt__(self, other):
out = tensor_operator_registry.get('__lt__')(self, other)
return out
def __le__(self, other):
out = tensor_operator_registry.get('__le__')(self, other)
return out return out
def __gt__(self, other):
out = tensor_operator_registry.get('__gt__')(self, other)
return out
def __ge__(self, other):
out = tensor_operator_registry.get('__ge__')(self, other)
return out
def __len__(self):
out = tensor_operator_registry.get('__shape__')(self)
if not out:
return 1
return out[0]
def __str__(self): def __str__(self):
if self.dtype() == mstype.type_none: if self.dtype() == mstype.type_none:
return "Unknown Tensor type!" return "Unknown Tensor type!"
......
...@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul") ...@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
stop_gradient = Primitive("stop_gradient") stop_gradient = Primitive("stop_gradient")
tensor_operator_registry.register('__add__', tensor_add) tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul) tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div) tensor_operator_registry.register('__div__', tensor_div)
#ms cannot support Tensor(True) compare #ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__shape__', shape)
...@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None): ...@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def __init__(self): def __init__(self):
op_name = name if name else fn.__name__ op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name) PrimitiveWithInfer.__init__(self, op_name)
self.const_value = True
def infer_value(self, *args): def infer_value(self, *args):
return fn(*args) return fn(*args)
......
...@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len ...@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
from mindspore.common.api import ms_function from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops.composite import core from mindspore.ops.composite import core
from mindspore.ops.primitive import constexpr
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
...@@ -417,3 +418,11 @@ def test_range(): ...@@ -417,3 +418,11 @@ def test_range():
""" test_range """ """ test_range """
res = range_spec(10, 10) res = range_spec(10, 10)
return res return res
def test_expr():
""" test const expr """
a = (1, 2)
@constexpr
def tuple_len(x):
assert len(x) == 2
tuple_len(a)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册