未验证 提交 8ddf51ff 编写于 作者: X xiaoguoguo626807 提交者: GitHub

add eq and hash (#55909)

上级 275a8102
......@@ -176,7 +176,13 @@ void BindValue(py::module *m) {
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def("__eq__", &Value::operator==);
.def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
return self.impl() == other.value_impl();
})
.def("__hash__",
[](const Value &self) { return std::hash<ir::Value>{}(self); });
}
void BindOpOperand(py::module *m) {
......@@ -218,18 +224,27 @@ void BindOpResult(py::module *m) {
ir::ArrayAttribute::get(ir::IrContext::Instance(),
stop_gradients));
})
.def("get_stop_gradient", [](OpResult &self) {
auto *defining_op = self.owner();
if (defining_op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
return stop_gradients[self.GetResultIndex()]
.dyn_cast<ir::BoolAttribute>()
.data();
} else {
return false;
}
.def("get_stop_gradient",
[](OpResult &self) {
auto *defining_op = self.owner();
if (defining_op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.AsVector();
return stop_gradients[self.GetResultIndex()]
.dyn_cast<ir::BoolAttribute>()
.data();
} else {
return false;
}
})
.def("__eq__", &OpResult::operator==)
.def("__eq__",
[](OpResult &self, Value &other) {
return self.value_impl() == other.impl();
})
.def("__hash__", [](OpResult &self) {
return std::hash<ir::Value>{}(self.dyn_cast<ir::Value>());
});
}
......
......@@ -122,6 +122,15 @@ detail::OpResultImpl *OpResult::impl() const {
return reinterpret_cast<detail::OpResultImpl *>(impl_);
}
bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}
detail::ValueImpl *OpResult::value_impl() const {
IR_ENFORCE(impl_, "Can't use value_impl() interface while value is null.");
return impl_;
}
uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
uint32_t max_inline_index =
ir::detail::OpResultImpl::GetMaxInlineResultIndex();
......
......@@ -192,8 +192,12 @@ class IR_API OpResult : public Value {
uint32_t GetResultIndex() const;
bool operator==(const OpResult &other) const;
friend Operation;
detail::ValueImpl *value_impl() const;
private:
static uint32_t GetValidInlineIndex(uint32_t index);
......@@ -209,4 +213,5 @@ struct hash<ir::Value> {
return std::hash<const ir::detail::ValueImpl *>()(obj.impl_);
}
};
} // namespace std
......@@ -84,12 +84,26 @@ class TestPybind(unittest.TestCase):
matmul_op.result(0).set_stop_gradient(True)
self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)
# test opresult hash
result_set = set()
for opresult in matmul_op.results():
result_set.add(opresult)
# self.assertTrue(add_op.operands()[0].source() in result_set)
# self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],)
# test opresult hash and hash(opresult) == hash(operesult)
self.assertTrue(add_op.operands()[0].source() in result_set)
# test value hash and hash(value) == hash(operesult)
self.assertTrue(add_op.operands_source()[0] in result_set)
# test value == value
self.assertEqual(
add_op.operands_source()[0], add_op.operands_source()[0]
)
# test value == opresult
self.assertEqual(add_op.operands_source()[0], matmul_op.results()[0])
# test opresult == value
self.assertEqual(
add_op.operands()[0].source(), add_op.operands_source()[0]
)
# test opresult == opresult
self.assertEqual(add_op.operands()[0].source(), matmul_op.results()[0])
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
......@@ -100,10 +114,6 @@ class TestPybind(unittest.TestCase):
tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul"
)
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op(),
tanh_op.operands_source()[0].get_defining_op(),
)
self.assertEqual(add_op.result(0).use_empty(), True)
def test_type(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册