diff --git a/paddle/cinn/hlir/framework/op_lowering_util.cc b/paddle/cinn/hlir/framework/op_lowering_util.cc index 06ec4488203621f1b45ebaab3e415fd686e0134c..77443cc86d025b1e94f746f49414584c69fe7601 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/op_lowering_util.cc @@ -825,6 +825,7 @@ bool CanbeInline(Node* node, } auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + for (auto consumer : consumers) { if (op_pattern_dict[consumer->op()] == framework::kReduction) { return false; diff --git a/paddle/cinn/ir/ir_visitor.cc b/paddle/cinn/ir/ir_visitor.cc index 83090fc9e75d6f4935ccfd95b49ac0126474a087..50d81b839bc4119ad9586d300a7255104b87e82e 100644 --- a/paddle/cinn/ir/ir_visitor.cc +++ b/paddle/cinn/ir/ir_visitor.cc @@ -16,6 +16,7 @@ #include +#include "paddle/cinn/ir/ir_compare.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/utils/string.h" @@ -25,8 +26,8 @@ namespace ir { bool operator==(Expr a, Expr b) { if (a.get() == b.get()) return true; - // TODO(Superjomn) implement with a more accurate one - return utils::GetStreamCnt(a) == utils::GetStreamCnt(b); + IrEqualVisitor cmp; + return cmp.Compare(a, b); } bool operator!=(Expr a, Expr b) { return !(a == b); } diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 8af88595fd06ddec2ebeffc1d56a578956add1a4..4424b75cef179e89681991ac5356ed6b43d95372 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -112,6 +112,7 @@ const CinnCompiledObject &CinnCompiler::Compile( auto compiled_res = CompileGraph(graph, input_tensors, target, compiled_num, stream); + std::unique_lock guard(lock_); // double check cache_by_struct_ if (!cache_by_struct_.count(cur_key_by_struct)) {