From 030b23da86cfd26e072642e27ee50c8a23544b91 Mon Sep 17 00:00:00 2001 From: Tomasz Socha Date: Thu, 2 Jun 2022 09:35:10 +0200 Subject: [PATCH] Fix for Bfloat16 placement pass. (#43109) * Fix bfloat16 placement pass * Make it nicer * Fix leftovers * Style --- paddle/fluid/framework/ir/graph_pattern_detector.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index f7c1a68c82..ea101125b1 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2631,8 +2631,10 @@ PDNode *patterns::Bfloat16Placement::operator()( PDNode *patterns::OrphanedBfloat16::operator()() { auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); prev_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") == - "float32"; + bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type"); + bool data_type_is_fp32 = node->Op()->GetAttrIfExists( + "mkldnn_data_type") == "float32"; + return data_type_is_missing || data_type_is_fp32; }); auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput(); @@ -2645,8 +2647,10 @@ PDNode *patterns::OrphanedBfloat16::operator()() { auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); next_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") == - "float32"; + bool data_type_is_missing = !node->Op()->HasAttr("mkldnn_data_type"); + bool data_type_is_fp32 = node->Op()->GetAttrIfExists( + "mkldnn_data_type") == "float32"; + return data_type_is_missing || data_type_is_fp32; }); prev_op->LinksTo({prev_out}); -- GitLab