diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 732e31d55b267d030e79b447ab402c26736d72c6..b7cba781007f4bc7efc06ac77a94e531bfec94ef 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2412,6 +2412,23 @@ PDNode *patterns::OrphanedBfloat16::operator()() {
   return next_op;
 }
 
+PDNode *patterns::UnsupportedBfloat16::operator()() {
+  auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
+  prev_op->assert_more([&](Node *node) {
+    return node->Op()->HasAttr("mkldnn_data_type") == false;
+  });
+  auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput();
+
+  auto *op = pattern->NewNode(op_repr())->assert_is_op();
+  op->assert_more([&](Node *node) {
+    return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
+           "bfloat16";
+  });
+  prev_op->LinksTo({prev_out});
+  op->LinksFrom({prev_out});
+  return op;
+}
+
 PDNode *patterns::LastBfloat16Ops::operator()() {
   auto *op = pattern->NewNode(op_repr())->assert_is_op();
   op->assert_more([&](Node *node) {
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index b15a75312dd24e0892dfc74040cac31181d2577c..7d143129ebd346f6af1c2637566094076306d63d 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1416,6 +1416,16 @@ struct OrphanedBfloat16 : public PatternBase {
   PATTERN_DECL_NODE(next_op);
 };
 
+struct UnsupportedBfloat16 : public PatternBase {
+  UnsupportedBfloat16(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "unsupported_bfloat16") {}
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(prev_op);
+  PATTERN_DECL_NODE(prev_out);
+  PATTERN_DECL_NODE(op);
+};
+
 struct LastBfloat16Ops : public PatternBase {
   LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "last_bfloat16_ops") {}
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
index 531a04e1a0d4c11799e8dea520faed447de4e808..0f9edeba525b02c1512df0e756c1a03233f3fc5b 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
@@ -71,10 +71,31 @@ void CPUBfloat16PlacementPass::RemoveOrphanedOperators(
   gpd(graph, handler);
 }
 
+void CPUBfloat16PlacementPass::RemoveUnsupportedOperators(
+    ir::Graph* graph, int* bfloat16_operators) const {
+  // Quantization currently supports FP32 inputs only, so find bfloat16
+  // operators whose input variable type is not FP32 and revert them.
+  GraphPatternDetector gpd;
+  patterns::UnsupportedBfloat16 unsupported_bfloat16_pattern{
+      gpd.mutable_pattern(), "unsupported_bfloat16"};
+  unsupported_bfloat16_pattern();
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(prev_out, prev_out, unsupported_bfloat16_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(op, op, unsupported_bfloat16_pattern);
+    if (prev_out->Var()->GetDataType() != proto::VarType::FP32) {
+      op->Op()->SetAttr("mkldnn_data_type", std::string("float32"));
+      (*bfloat16_operators)--;
+    }
+  };
+  gpd(graph, handler);
+}
+
 void CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const {
   int bfloat16_operators = 0;
   SetMkldnnDataType(graph, &bfloat16_operators);
   RemoveOrphanedOperators(graph, &bfloat16_operators);
+  RemoveUnsupportedOperators(graph, &bfloat16_operators);
   PrettyLogDetail("--- marked %d operators to bfloat16 ",
                   bfloat16_operators);
 }
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
index 53b97f0e9726aacf86f6f71d3382ab25241e3cdb..facc4c4c5522122f7edb1cf9bdea73416e1deaee 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h
@@ -30,6 +30,9 @@ class CPUBfloat16PlacementPass : public Pass {
   void RemoveOrphanedOperators(ir::Graph* graph,
                                int* bfloat16_operators) const;
 
+  void RemoveUnsupportedOperators(ir::Graph* graph,
+                                  int* bfloat16_operators) const;
+
   void ApplyImpl(ir::Graph* graph) const override;
 };
 