diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index f7c1a68c826f0935fb6c551a744776679fc0bb69..ea101125b18d2d1ac01df69c4d54c28fd7243c7b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2631,8 +2631,10 @@ PDNode *patterns::Bfloat16Placement::operator()( PDNode *patterns::OrphanedBfloat16::operator()() { auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); prev_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") == - "float32"; + bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type"); + bool data_type_is_fp32 = node->Op()->GetAttrIfExists( + "mkldnn_data_type") == "float32"; + return data_type_is_missing || data_type_is_fp32; }); auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput(); @@ -2645,8 +2647,10 @@ PDNode *patterns::OrphanedBfloat16::operator()() { auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); next_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") == - "float32"; + bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type"); + bool data_type_is_fp32 = node->Op()->GetAttrIfExists( + "mkldnn_data_type") == "float32"; + return data_type_is_missing || data_type_is_fp32; }); prev_op->LinksTo({prev_out});