diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 42d1b3fe555a97ba4168e205217867e35c4b0894..47a0a30b5667ddc97b3783ab9edbab04281528a4 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -157,16 +157,8 @@ struct PMNode { template PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { - asserts_.push_back([=](const Node* x) { - if (x && x->IsStmt()) { - auto* op_info = x->stmt()->op_info(); - bool cond = (op_info->HasAttr(attr_name) && - op_info->GetAttr(attr_name) == attr); - return cond; - } - return false; - }); - return this; + return assert_op_attr_satisfied( + attr_name, [=](const T& src) { return src == attr; }); } private: