diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/ccsrc/ir/meta_tensor.cc
index fe41abcef42eb1dd18f02ee66014f99d0b66c04e..af6b4f7ffcf9c63da1f255a52f41b84eec93e26a 100644
--- a/mindspore/ccsrc/ir/meta_tensor.cc
+++ b/mindspore/ccsrc/ir/meta_tensor.cc
@@ -185,14 +185,6 @@ bool Tensor::operator==(const Tensor &tensor) const {
   return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
 }
 
-bool Tensor::ValueEqualPy(const py::object &other) const {
-  if (!py::isinstance<Tensor>(other)) {
-    MS_LOG(WARNING) << "compare other not a tensor";
-    return false;
-  }
-  return ValueEqual(py::cast<Tensor>(other));
-}
-
 bool Tensor::ValueEqual(const Tensor &other) const {
   auto equal = [&other, this]() -> bool {
     auto np = py::module::import("numpy");
@@ -542,7 +534,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                              )mydelimiter")
                            .def("__str__", &Tensor::ToString)
                            .def("__repr__", &Tensor::ToStringRepr)
-                           .def("__eq__", &Tensor::ValueEqualPy)
                            .def(py::pickle(
                              [](const Tensor &t) {  // __getstate__
                                /* Return a tuple that fully encodes the state of the object */
diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h
index 1f6c866f11cc4983ac5bbbefd88f8f60629bc4b5..ff76a1d4f9da57b784d860c26648cd613b789951 100644
--- a/mindspore/ccsrc/ir/meta_tensor.h
+++ b/mindspore/ccsrc/ir/meta_tensor.h
@@ -329,9 +329,6 @@ class Tensor : public MetaTensor {
   // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
   bool ValueEqual(const Tensor &other) const;
 
-  // It is different from 'operator==' which just compare shape/type/address, it do real value comparison.
-  bool ValueEqualPy(const py::object &other) const;
-
   bool operator==(const Value &other) const override {
     if (other.isa<Tensor>()) {
       auto other_ = static_cast<const Tensor &>(other);
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 70b8b169ca1466e6632bd80c85e2f4f593ce01e8..5504f2b483924ed6d4c7563c519aedf9910e1faf 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -74,6 +74,17 @@ class Tensor(Tensor_):
         out = tensor_operator_registry.get('__add__')(self, other)
         return out
 
+    def __eq__(self, other):
+        if not isinstance(other, Tensor):
+            return False
+        x = self.asnumpy()
+        y = other.asnumpy()
+        out = np.equal(x, y)
+        return Tensor(np.array(out))
+
+    def __hash__(self):
+        return hash(id(self))
+
     def __mul__(self, other):
         check_type('tensor input_data', other, (Tensor, float, int))
         out = tensor_operator_registry.get('__mul__')(self, other)
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 4135133e855af58d9cd0520cca34b8ae45c40994..a2473fe7095c90201b468dd1c7e92468c83d6c69 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -144,3 +144,5 @@ stop_gradient = Primitive("stop_gradient")
 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__div__', tensor_div)
+# ms cannot support Tensor(True) compare
+tensor_operator_registry.register('__eq__', equal)
diff --git a/tests/vm_impl/math_ops_vm_impl.py b/tests/vm_impl/math_ops_vm_impl.py
index 01df0b824e280f31ced734a29f4fbad831ecd63c..e42ba92d5e83766de6135b14f4fde3b8ad7f3210 100644
--- a/tests/vm_impl/math_ops_vm_impl.py
+++ b/tests/vm_impl/math_ops_vm_impl.py
@@ -172,7 +172,7 @@ def vm_impl_equal(self):
         x = x.asnumpy()
         y = y.asnumpy()
         out = vm.equal(x, y)
-        return Tensor(out)
+        return Tensor(np.array(out))
     return vm_impl
 
 
@@ -183,7 +183,7 @@ def vm_impl_not_equal(self):
         x = x.asnumpy()
         y = y.asnumpy()
         out = vm.not_equal(x, y)
-        return Tensor(out)
+        return Tensor(np.array(out))
     return vm_impl
 
 
@@ -194,7 +194,7 @@ def vm_impl_greater(self):
         x = x.asnumpy()
         y = y.asnumpy()
         out = vm.greater(x, y)
-        return Tensor(out)
+        return Tensor(np.array(out))
     return vm_impl
 
 @vm_impl_getters.register(P.Maximum)
@@ -219,17 +219,17 @@ def vm_impl_minimum(self):
     return vm_impl
 
 @vm_impl_getters.register(P.Less)
-def vm_impl_greater(self):
+def vm_impl_less(self):
     """Generate vm_impl function for Less"""
     def vm_impl(x, y):
         x = x.asnumpy()
         y = y.asnumpy()
         out = vm.less(x, y)
-        return Tensor(out)
+        return Tensor(np.array(out))
     return vm_impl
 
 @vm_impl_getters.register(P.ScalarCast)
-def vm_impl_greater(self):
+def vm_impl_scalar_cast(self):
     """Generate vm_impl function for ScalarCast"""
     def vm_impl(x, t):
         np_type = dtype_to_nptype(t)
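
A minimal usage sketch of the Python-side comparison semantics introduced above. This is illustrative only and not part of the patch; it assumes a MindSpore build with this change applied and numpy available.

import numpy as np
from mindspore import Tensor

a = Tensor(np.array([1, 2, 3], np.int32))
b = Tensor(np.array([1, 0, 3], np.int32))

# __eq__ now compares values elementwise via np.equal and returns a boolean
# Tensor, replacing the removed C++ ValueEqualPy binding.
print((a == b).asnumpy())   # [ True False  True]

# Comparing against a non-Tensor simply yields False.
print(a == 5)               # False

# __hash__ is identity-based, so Tensors remain usable as dict keys even
# though __eq__ no longer returns a plain bool.
cache = {a: 'lhs', b: 'rhs'}
print(cache[a])             # lhs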