// Copyright (c) 2022 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/ir/ir_compare.h" #include #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_printer.h" namespace cinn { namespace ir { bool IrEqualVisitor::Compare(const Expr& lhs, const Expr& rhs) { if (lhs.get() == rhs.get()) { // the same object, including both are null return true; } if (!lhs.defined() || !rhs.defined()) { // someone invalid return false; VLOG(5) << "Not equal on Expr, someone not defined"; } bool equal = lhs->node_type() == rhs->node_type(); equal = equal && IRVisitorBase::Visit(&lhs, &rhs); if (!equal) { VLOG(5) << "Not equal on Expr, lhs:[type:" << kIrNodeTyReprs[static_cast(lhs->node_type())] << "]\n" << lhs << ", \nrhs[type:" << kIrNodeTyReprs[static_cast(rhs->node_type())] << "]\n" << rhs; } return equal; } bool IrEqualVisitor::Compare(const std::string& lhs, const std::string& rhs, bool allow_name_suffix_diff) { // if allow_name_suffix_diff=true then just compare the name prefix before the "_[0-9]+" auto common_len = 0; for (; common_len < lhs.size() && common_len < rhs.size(); ++common_len) { if (lhs[common_len] != rhs[common_len]) break; } auto is_endswith_index = [&common_len](const std::string& name) { const std::regex txt_regex("_\\d+"); return common_len == name.size() || std::regex_match(name.substr(common_len), txt_regex); }; bool equal = false; if (common_len == lhs.size() && common_len == rhs.size()) { equal = true; } else { equal = false; if (allow_name_suffix_diff) { equal = is_endswith_index(lhs) && is_endswith_index(rhs); } } if (!equal) { VLOG(5) << "Not euqal on name, lhs=" << lhs << ", rhs=" << rhs; } return equal; } bool IrEqualVisitor::Compare(const std::map& lhs, const std::map& rhs) { if (lhs.size() != rhs.size()) { VLOG(6) << "Not equal on attrs, lhs size=" << lhs.size() << ", rhs size=" << rhs.size(); return false; } for (auto&& kv : lhs) { auto opposite = rhs.find(kv.first); if (opposite == rhs.end() || kv.second != opposite->second) { VLOG(6) << "Not equal at attr key=" << kv.first; return false; } } return true; } template bool IrEqualVisitor::Compare(const std::vector& lhs, const std::vector& rhs) { if (lhs.size() != rhs.size()) { VLOG(6) << "Not equal on repeated fields, lhs size=" << lhs.size() << ", rhs size=" << rhs.size(); return false; } for (auto i = 0; i < lhs.size(); ++i) { if (!Compare(lhs.at(i), rhs.at(i))) { VLOG(6) << "Not equal on repeated fields at index=" << i; return false; } } return true; } #define PRIMITIVE_TYPE_IMPL(op__) \ bool IrEqualVisitor::Visit(const op__* lhs, const Expr* other) { \ auto* rhs = other->As(); \ return lhs->value == rhs->value; \ } #define UNARY_OP_IMPL(op__) \ bool IrEqualVisitor::Visit(const op__* lhs, const Expr* other) { \ auto* rhs = other->As(); \ return Compare(lhs->v(), rhs->v()); \ } #define BINARY_OP_IMPL(op__) \ bool IrEqualVisitor::Visit(const op__* lhs, const Expr* other) { \ auto* rhs = other->As(); \ return Compare(lhs->a(), rhs->a()) && Compare(lhs->b(), rhs->b()); \ } NODETY_PRIMITIVE_TYPE_FOR_EACH(PRIMITIVE_TYPE_IMPL) NODETY_UNARY_OP_FOR_EACH(UNARY_OP_IMPL) NODETY_BINARY_OP_FOR_EACH(BINARY_OP_IMPL) #undef PRIMITIVE_TYPE_IMPL #undef UNARY_OP_IMPL #undef BINARY_OP_IMPL bool IrEqualVisitor::Visit(const Cast* lhs, const Expr* other) { auto* rhs = other->As(); return lhs->type() == rhs->type() && Compare(lhs->v(), rhs->v()); } bool IrEqualVisitor::Visit(const For* lhs, const Expr* other) { auto* rhs = other->As(); return lhs->for_type() == rhs->for_type() && Compare(lhs->loop_var, rhs->loop_var) && Compare(lhs->min, rhs->min) && Compare(lhs->extent, rhs->extent) && Compare(lhs->body, rhs->body); } bool IrEqualVisitor::Visit(const PolyFor* lhs, const Expr* other) { auto* rhs = other->As(); return lhs->for_type() == rhs->for_type() && Compare(lhs->iterator, rhs->iterator) && Compare(lhs->init, rhs->init) && Compare(lhs->condition, rhs->condition) && Compare(lhs->inc, rhs->inc) && Compare(lhs->body, rhs->body); } bool IrEqualVisitor::Visit(const Select* lhs, const Expr* other) { auto* rhs = other->As