diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 6949e4d078c0cd3af700a29307633846b08118d0..8c4965fc4023556a8809f9986318095c8d505e97 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2441,11 +2441,13 @@ PDNode *patterns::Bfloat16Placement::operator()(
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
+  auto *op_in = pattern->NewNode(op_in_repr())->AsInput();
   auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
   op->assert_more([&](Node *node) {
     return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
            node->Op()->Type() == "reshape2";
   });
+  op->LinksFrom({op_in});
   return op;
 }
 
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 940f6b8561e48b69687180187f2a2ba1af55deb6..5b996a3ab918b0323b2fd95a6d868e1419ccfc24 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1446,6 +1446,7 @@ struct Bfloat16Placement : public PatternBase {
   PDNode* operator()(
       const std::unordered_set<std::string>& bfloat16_enabled_op_types);
 
+  PATTERN_DECL_NODE(op_in);
   PATTERN_DECL_NODE(op);
 };
 
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
index 0f9edeba525b02c1512df0e756c1a03233f3fc5b..d89891ec3c857b64bc51e8e1b63effde9ca4bb3d 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc
@@ -41,8 +41,12 @@ void CPUBfloat16PlacementPass::SetMkldnnDataType(
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_placement_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern);
 
+    // Only float input can be converted to bfloat16
+    if (op_in->Var()->GetDataType() != proto::VarType::FP32) return;
+
     if ((op->Op()->HasAttr("mkldnn_data_type") ||
          op->Op()->HasProtoAttr("mkldnn_data_type")) &&
         !platform::HasOpINT8DataType(op->Op())) {
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc
index 28a45f36fb71d7aa5e13128afa2e06301dbbcef9..e3ef7b7af05d2945f199f0ccaf19c39a8416bf37 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc
@@ -68,7 +68,7 @@ ProgramDesc BuildProgramDesc() {
   for (auto& v :
        std::vector<std::string>({"a", "b", "c", "f", "g", "h", "k", "l", "m",
                                  "n", "o", "p", "r", "s"})) {
-    prog.MutableBlock(0)->Var(v);
+    prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
   }
 
   SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
@@ -86,9 +86,8 @@ ProgramDesc BuildProgramDesc() {
 }
 
 void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
-              unsigned expected_bfloat16_data_type_count) {
-  auto prog = BuildProgramDesc();
-
+              unsigned expected_bfloat16_data_type_count,
+              const ProgramDesc& prog) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
@@ -110,8 +109,8 @@ void MainTest(std::initializer_list<std::string> bfloat16_enabled_op_types,
   EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count);
 }
 
-void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
-  auto prog = BuildProgramDesc();
+void DefaultAttrTest(unsigned expected_bfloat16_data_type_count,
+                     const ProgramDesc& prog) {
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
   auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
   graph.reset(pass->Apply(graph.release()));
@@ -128,15 +127,39 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
 }
 
 TEST(Bfloat16PlacementPass, enable_all) {
-  MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8);
+  MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8,
+           BuildProgramDesc());
 }
 
 TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
   // 2 conv2d + 2 pool2 - 1 orphaned conv2d
-  MainTest({"conv2d", "pool2d"}, 3);
+  MainTest({"conv2d", "pool2d"}, 3, BuildProgramDesc());
+}
+
+TEST(Bfloat16PlacementPass, default_attr_value) {
+  DefaultAttrTest(10, BuildProgramDesc());
+}
+
+ProgramDesc BuildProgramDescWithDataType() {
+  ProgramDesc prog;
+
+  for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e"})) {
+    if (v == "a") {
+      prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::INT32);
+    } else {
+      prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
+    }
+  }
+
+  SetOp(&prog, "conv2d", "conv1", {"a"}, {"b"});
+  SetOp(&prog, "pool2d", "pool1", {"b"}, {"c"});
+  SetOp(&prog, "concat", "concat1", {"c", "d"}, {"e"});
+  return prog;
 }
 
-TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(10); }
+TEST(Bfloat16PlacementPass, check_data_types) {
+  DefaultAttrTest(2, BuildProgramDescWithDataType());
+}
 
 }  // namespace ir
 }  // namespace framework
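
Note on the `check_data_types` expectation (editorial commentary, not part of the patch): in `BuildProgramDescWithDataType()` only var `a` is `INT32`, so the guard added in `cpu_bfloat16_placement_pass.cc` skips `conv1` (its input is not `FP32`), while `pool1` and `concat1` read only `FP32` vars and get tagged with `mkldnn_data_type = "bfloat16"`, giving the expected count of 2. The standalone sketch below models that guard in isolation; the `VarType` enum and `ShouldPlaceBfloat16` helper are hypothetical stand-ins, not Paddle APIs.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the proto::VarType data types checked by the pass.
enum class VarType { FP32, INT32 };

// Mirrors the guard added to CPUBfloat16PlacementPass::SetMkldnnDataType:
// an op is tagged for bfloat16 only when its input variable is FP32.
bool ShouldPlaceBfloat16(VarType input_type) {
  return input_type == VarType::FP32;
}

int main() {
  // conv1 reads "a" (INT32) -> skipped; pool1/concat1 read FP32 -> tagged.
  std::vector<std::pair<std::string, VarType>> ops = {
      {"conv1", VarType::INT32},
      {"pool1", VarType::FP32},
      {"concat1", VarType::FP32}};

  int bfloat16_count = 0;
  for (const auto& op : ops) {
    if (ShouldPlaceBfloat16(op.second)) ++bfloat16_count;
  }
  assert(bfloat16_count == 2);  // matches DefaultAttrTest(2, ...)
  return 0;
}
```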