未验证 提交 8ddf51ff 编写于 作者: X xiaoguoguo626807 提交者: GitHub

add eq and hash (#55909)

上级 275a8102
...@@ -176,7 +176,13 @@ void BindValue(py::module *m) { ...@@ -176,7 +176,13 @@ void BindValue(py::module *m) {
.def("get_defining_op", .def("get_defining_op",
&Value::GetDefiningOp, &Value::GetDefiningOp,
return_value_policy::reference) return_value_policy::reference)
.def("__eq__", &Value::operator==); .def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
return self.impl() == other.value_impl();
})
.def("__hash__",
[](const Value &self) { return std::hash<ir::Value>{}(self); });
} }
void BindOpOperand(py::module *m) { void BindOpOperand(py::module *m) {
...@@ -218,7 +224,8 @@ void BindOpResult(py::module *m) { ...@@ -218,7 +224,8 @@ void BindOpResult(py::module *m) {
ir::ArrayAttribute::get(ir::IrContext::Instance(), ir::ArrayAttribute::get(ir::IrContext::Instance(),
stop_gradients)); stop_gradients));
}) })
.def("get_stop_gradient", [](OpResult &self) { .def("get_stop_gradient",
[](OpResult &self) {
auto *defining_op = self.owner(); auto *defining_op = self.owner();
if (defining_op->HasAttribute(kAttrStopGradients)) { if (defining_op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients = defining_op->attribute(kAttrStopGradients) auto stop_gradients = defining_op->attribute(kAttrStopGradients)
...@@ -230,6 +237,14 @@ void BindOpResult(py::module *m) { ...@@ -230,6 +237,14 @@ void BindOpResult(py::module *m) {
} else { } else {
return false; return false;
} }
})
.def("__eq__", &OpResult::operator==)
.def("__eq__",
[](OpResult &self, Value &other) {
return self.value_impl() == other.impl();
})
.def("__hash__", [](OpResult &self) {
return std::hash<ir::Value>{}(self.dyn_cast<ir::Value>());
}); });
} }
......
...@@ -122,6 +122,15 @@ detail::OpResultImpl *OpResult::impl() const { ...@@ -122,6 +122,15 @@ detail::OpResultImpl *OpResult::impl() const {
return reinterpret_cast<detail::OpResultImpl *>(impl_); return reinterpret_cast<detail::OpResultImpl *>(impl_);
} }
bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}
detail::ValueImpl *OpResult::value_impl() const {
IR_ENFORCE(impl_, "Can't use value_impl() interface while value is null.");
return impl_;
}
uint32_t OpResult::GetValidInlineIndex(uint32_t index) { uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
uint32_t max_inline_index = uint32_t max_inline_index =
ir::detail::OpResultImpl::GetMaxInlineResultIndex(); ir::detail::OpResultImpl::GetMaxInlineResultIndex();
......
...@@ -192,8 +192,12 @@ class IR_API OpResult : public Value { ...@@ -192,8 +192,12 @@ class IR_API OpResult : public Value {
uint32_t GetResultIndex() const; uint32_t GetResultIndex() const;
bool operator==(const OpResult &other) const;
friend Operation; friend Operation;
detail::ValueImpl *value_impl() const;
private: private:
static uint32_t GetValidInlineIndex(uint32_t index); static uint32_t GetValidInlineIndex(uint32_t index);
...@@ -209,4 +213,5 @@ struct hash<ir::Value> { ...@@ -209,4 +213,5 @@ struct hash<ir::Value> {
return std::hash<const ir::detail::ValueImpl *>()(obj.impl_); return std::hash<const ir::detail::ValueImpl *>()(obj.impl_);
} }
}; };
} // namespace std } // namespace std
...@@ -84,12 +84,26 @@ class TestPybind(unittest.TestCase): ...@@ -84,12 +84,26 @@ class TestPybind(unittest.TestCase):
matmul_op.result(0).set_stop_gradient(True) matmul_op.result(0).set_stop_gradient(True)
self.assertEqual(matmul_op.result(0).get_stop_gradient(), True) self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)
# test opresult hash
result_set = set() result_set = set()
for opresult in matmul_op.results(): for opresult in matmul_op.results():
result_set.add(opresult) result_set.add(opresult)
# test opresult hash and hash(opresult) == hash(operesult)
# self.assertTrue(add_op.operands()[0].source() in result_set) self.assertTrue(add_op.operands()[0].source() in result_set)
# self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],) # test value hash and hash(value) == hash(operesult)
self.assertTrue(add_op.operands_source()[0] in result_set)
# test value == value
self.assertEqual(
add_op.operands_source()[0], add_op.operands_source()[0]
)
# test value == opresult
self.assertEqual(add_op.operands_source()[0], matmul_op.results()[0])
# test opresult == value
self.assertEqual(
add_op.operands()[0].source(), add_op.operands_source()[0]
)
# test opresult == opresult
self.assertEqual(add_op.operands()[0].source(), matmul_op.results()[0])
self.assertEqual( self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd.add" tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
...@@ -100,10 +114,6 @@ class TestPybind(unittest.TestCase): ...@@ -100,10 +114,6 @@ class TestPybind(unittest.TestCase):
tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul" tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul"
) )
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op(),
tanh_op.operands_source()[0].get_defining_op(),
)
self.assertEqual(add_op.result(0).use_empty(), True) self.assertEqual(add_op.result(0).use_empty(), True)
def test_type(self): def test_type(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册