From 9785178b8828d3522ce7431d1513148360ed176e Mon Sep 17 00:00:00 2001
From: kingfo
Date: Tue, 2 Jun 2020 21:01:56 +0800
Subject: [PATCH] add tensor compare & len & constexpr operation

---
 mindspore/ccsrc/pynative/pynative_execute.cc   |  5 +++
 mindspore/common/tensor.py                     | 34 +++++++++++++++----
 mindspore/ops/functional.py                    |  8 +++++
 mindspore/ops/primitive.py                     |  1 +
 .../python/pynative_mode/test_parse_method.py  |  9 +++++
 5 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index 4f2a96139..02e9ebabb 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
       value_ret[0] = output["value"];
       return value_ret;
     }
+    if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
+      py::tuple value_ret(1);
+      value_ret[0] = "";
+      return value_ret;
+    }
   }
   MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
   mindspore::parse::python_adapter::set_python_env_flag(true);
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index cbb705f84..ed8f1be0f 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -71,19 +71,18 @@ class Tensor(Tensor_):
         return str(self.__str__())

     def __add__(self, other):
-        check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__add__')(self, other)
         return out

     def __eq__(self, other):
         if not isinstance(other, Tensor):
             return False
-        return Tensor(np.array(self.asnumpy() == other.asnumpy()))
+        return tensor_operator_registry.get('__eq__')(self, other)

     def __ne__(self, other):
         if not isinstance(other, Tensor):
             return True
-        return Tensor(np.array(self.asnumpy() != other.asnumpy()))
+        return tensor_operator_registry.get('__ne__')(self, other)

     def __hash__(self):
         return hash(id(self))
@@ -93,7 +92,8 @@ class Tensor(Tensor_):
         return out

     def __neg__(self):
-        return Tensor(-self.asnumpy())
+        out = tensor_operator_registry.get('__neg__')(self)
+        return out

     def __iadd__(self, other):
         out = self.__add__(other)
@@ -120,7 +120,7 @@ class Tensor(Tensor_):
         return out

     def __sub__(self, other):
-        out = self.__add__(-other)
+        out = tensor_operator_registry.get('__sub__')(self, other)
         return out

     def __isub__(self, other):
@@ -128,9 +128,31 @@ class Tensor(Tensor_):
         return out

     def __rsub__(self, other):
-        out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
+        out = tensor_operator_registry.get('__sub__')(other, self)
+        return out
+
+    def __lt__(self, other):
+        out = tensor_operator_registry.get('__lt__')(self, other)
+        return out
+
+    def __le__(self, other):
+        out = tensor_operator_registry.get('__le__')(self, other)
         return out

+    def __gt__(self, other):
+        out = tensor_operator_registry.get('__gt__')(self, other)
+        return out
+
+    def __ge__(self, other):
+        out = tensor_operator_registry.get('__ge__')(self, other)
+        return out
+
+    def __len__(self):
+        out = tensor_operator_registry.get('__shape__')(self)
+        if not out:
+            return 1
+        return out[0]
+
     def __str__(self):
         if self.dtype() == mstype.type_none:
             return "Unknown Tensor type!"
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 8f5fcaefb..5edc2f809 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
 stop_gradient = Primitive("stop_gradient")

 tensor_operator_registry.register('__add__', tensor_add)
+tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__div__', tensor_div)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
+tensor_operator_registry.register('__ne__', not_equal)
+tensor_operator_registry.register('__neg__', neg_tensor)
+tensor_operator_registry.register('__lt__', tensor_lt)
+tensor_operator_registry.register('__le__', tensor_le)
+tensor_operator_registry.register('__gt__', tensor_gt)
+tensor_operator_registry.register('__ge__', tensor_ge)
+tensor_operator_registry.register('__shape__', shape)
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index beaf9b8a4..f456421f7 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
         def __init__(self):
             op_name = name if name else fn.__name__
             PrimitiveWithInfer.__init__(self, op_name)
+            self.const_value = True

         def infer_value(self, *args):
             return fn(*args)
diff --git a/tests/ut/python/pynative_mode/test_parse_method.py b/tests/ut/python/pynative_mode/test_parse_method.py
index a4b2907cc..abbfa6cd3 100644
--- a/tests/ut/python/pynative_mode/test_parse_method.py
+++ b/tests/ut/python/pynative_mode/test_parse_method.py
@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
 from mindspore.common.api import ms_function
 from mindspore.common.tensor import Tensor
 from mindspore.ops.composite import core
+from mindspore.ops.primitive import constexpr
 from ..ut_filter import non_graph_engine


@@ -417,3 +418,11 @@ def test_range():
     """ test_range """
     res = range_spec(10, 10)
     return res
+
+
+def test_expr():
+    """ test const expr """
+    a = (1, 2)
+    @constexpr
+    def tuple_len(x):
+        assert len(x) == 2
+    tuple_len(a)
--
GitLab
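
Usage sketch (illustration only, not part of the patch): the snippet below shows the behavior this change set enables. It assumes a PyNative-mode build that already includes the patch and a backend where the comparison/Shape primitives behind the new registry entries are supported; the names x, y, and check_pair are made up for illustration.

    # Illustrative sketch only, not part of the patch above.
    import numpy as np

    from mindspore import context
    from mindspore.common.tensor import Tensor
    from mindspore.ops.primitive import constexpr

    context.set_context(mode=context.PYNATIVE_MODE)

    x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
    y = Tensor(np.array([3.0, 2.0, 1.0], dtype=np.float32))

    # The new comparison dunders dispatch through tensor_operator_registry,
    # so they return element-wise boolean Tensors instead of numpy results.
    print((x < y).asnumpy())   # [ True False False]
    print((x >= y).asnumpy())  # [False  True  True]

    # __len__ is backed by the registered '__shape__' op: it returns the size
    # of the first dimension, or 1 for a scalar Tensor.
    print(len(x))              # 3

    # A constexpr primitive now carries const_value, so PyNative's RunOp treats
    # it as a compile-time constant instead of launching a device kernel; the
    # wrapped function still runs eagerly during inference, as in test_expr.
    @constexpr
    def check_pair(t):
        assert len(t) == 2

    check_pair((1, 2))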