diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 8787aa8a94a44c2c36868fea4b88ede5f91b19f4..5bb833f613529a81d5ae4e18fc5ad7cd1136354b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -102,6 +102,8 @@ if(WITH_MKLDNN) pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn) pass_library(scale_matmul_fuse_pass inference DIR mkldnn) + pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn) + pass_library(cpu_bfloat16_pass inference DIR mkldnn) pass_library(fc_mkldnn_pass inference DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) @@ -162,4 +164,6 @@ endif() cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass) cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) + cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) + cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) endif () diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 3d65fe595373fa98ba237f04134c75d4a60a7242..9c1eaa99a3ca04ddbeecab639d5587d5509e3f00 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1892,6 +1892,82 @@ PDNode *patterns::QuantizePlacement::operator()( return op; } +PDNode *patterns::Bfloat16Placement::operator()( + const 
std::unordered_set &bfloat16_enabled_op_types) { + std::unordered_set supported_op_types = + std::unordered_set(); + if (!bfloat16_enabled_op_types.empty()) { + supported_op_types = bfloat16_enabled_op_types; + } + auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types); + return op; +} + +PDNode *patterns::OrphanedBfloat16::operator()() { + auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); + prev_op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "float32"; + }); + auto *prev_out = pattern->NewNode(prev_out_repr())->AsOutput(); + + auto *op = pattern->NewNode(op_repr())->assert_is_op(); + op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "bfloat16"; + }); + auto *op_out = pattern->NewNode(op_out_repr())->AsOutput(); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + next_op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "float32"; + }); + + prev_op->LinksTo({prev_out}); + op->LinksFrom({prev_out}).LinksTo({op_out}); + next_op->LinksFrom({op_out}); + return next_op; +} + +PDNode *patterns::LastBfloat16Ops::operator()() { + auto *op = pattern->NewNode(op_repr())->assert_is_op(); + op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "bfloat16"; + }); + auto *op_out = pattern->NewNode(op_out_repr())->AsOutput(); + + auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); + next_op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") != + "bfloat16"; + }); + + op->LinksTo({op_out}); + next_op->LinksFrom({op_out}); + return next_op; +} + +PDNode *patterns::FirstBfloat16Ops::operator()() { + auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); + prev_op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") != + "bfloat16"; + }); + auto *op_in = 
pattern->NewNode(op_in_repr())->AsOutput(); + + auto *op = pattern->NewNode(op_repr())->assert_is_op(); + op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "bfloat16"; + }); + + prev_op->LinksTo({op_in}); + op->LinksFrom({op_in}); + return op; +} + PDNode *patterns::MKLDNNInPlace::operator()() { const std::unordered_set &supported_op_types = { "abs", diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 0803265884165bc754489b18d07c0d277a4bd92b..053c1fe832b0088d2abdd3f8eb40a0042e5e2dfe 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1129,6 +1129,47 @@ struct QuantizePlacement : public PatternBase { PATTERN_DECL_NODE(op); }; +struct Bfloat16Placement : public PatternBase { + Bfloat16Placement(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "bfloat16_placement") {} + PDNode* operator()( + const std::unordered_set& bfloat16_enabled_op_types); + + PATTERN_DECL_NODE(op); +}; + +struct OrphanedBfloat16 : public PatternBase { + OrphanedBfloat16(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "orphaned_bfloat16") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(prev_op); + PATTERN_DECL_NODE(prev_out); + PATTERN_DECL_NODE(op); + PATTERN_DECL_NODE(op_out); + PATTERN_DECL_NODE(next_op); +}; + +struct LastBfloat16Ops : public PatternBase { + LastBfloat16Ops(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "last_bfloat16_ops") {} + PDNode* operator()(); + + PATTERN_DECL_NODE(op); + PATTERN_DECL_NODE(op_out); + PATTERN_DECL_NODE(next_op); +}; + +struct FirstBfloat16Ops : public PatternBase { + FirstBfloat16Ops(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "first_bfloat16_ops") {} + PDNode* operator()(); + + 
PATTERN_DECL_NODE(prev_op); + PATTERN_DECL_NODE(op_in); + PATTERN_DECL_NODE(op); +}; + // Pattern used for enforcing inplace computation for in-place computation // supporting DNNL ops. softmax, batch_norm and layer_norm struct MKLDNNInPlace : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..df498865245fc8054f9521026e0b5cd6906b136f --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc @@ -0,0 +1,159 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void UnlinkNodes(ir::Node* a, ir::Node* b) { + a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b), + a->outputs.end()); + b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a), + b->inputs.end()); +} + +void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), + "first_bfloat16_ops"}; + bfloat16_ops(); + int quantize_counter = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops); + GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops); + GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); + + if (op->Op()->Type() != "conv2d" && prev_op->Op()->Type() != "quantize") { + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); + auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc); + + // create a quantize op node + OpDesc q_desc; + q_desc.SetType("quantize"); + q_desc.SetInput("Input", std::vector({op_in->Name()})); + q_desc.SetOutput("Output", + std::vector({quantize_out_node->Name()})); + q_desc.SetAttr("Scale", 1.f); + q_desc.SetAttr("bfloat16", true); + q_desc.SetAttr("output_format", Has("data_layout") + ? Get("data_layout") + : "NCHW"); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. 
+ + std::string op_input_name; + for (const auto& name : op->Op()->InputNames()) { // find the input slot that carries op_in + for (const auto& input_name : op->Op()->Input(name)) { + if (input_name == op_in->Name()) op_input_name = name; + } + } + + PADDLE_ENFORCE_NE( + op_input_name.empty(), true, + platform::errors::NotFound( + "The bfloat16 operator has no input slot holding the output of the preceding operator.")); + + op->Op()->SetInput(op_input_name, + std::vector({quantize_out_node->Name()})); + + UnlinkNodes(op_in, op); // rewire: op_in -> quantize -> quantize_out -> op + IR_NODE_LINK_TO(op_in, quantize_op); + IR_NODE_LINK_TO(quantize_op, quantize_out_node); + IR_NODE_LINK_TO(quantize_out_node, op); + quantize_counter++; + } + }; + gpd(graph, handler); + PrettyLogDetail("--- added %d quantize op before bfloat16 op", + quantize_counter); +} + +void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), + "last_bfloat16_ops"}; + bfloat16_ops(); + int force_fp32_counter = 0, dequantize_counter = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); + GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops); + GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops); + + if ((op->Op()->HasAttr("force_fp32_output") || + op->Op()->HasProtoAttr("force_fp32_output")) && + !op->Op()->GetAttrIfExists("fuse_residual_connection")) { + op->Op()->SetAttr("force_fp32_output", true); + force_fp32_counter++; + } else if (op->Op()->Type() != "prior_box") { + // Create dequantize input variable + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); + + // create a dequantize op node for output. 
+ OpDesc deq_desc; + deq_desc.SetType("dequantize"); + deq_desc.SetInput("Input", + std::vector({dequantize_in_node->Name()})); + deq_desc.SetOutput("Output", std::vector({op_out->Name()})); + deq_desc.SetAttr("Scale", 1.0f); + auto dequantize_op = g->CreateOpNode(&deq_desc); + + std::string op_output_name; + for (const auto& name : op->Op()->OutputNames()) { // find the output slot that carries op_out + for (const auto& output_name : op->Op()->Output(name)) { + if (output_name == op_out->Name()) op_output_name = name; + } + } + + PADDLE_ENFORCE_NE( + op_output_name.empty(), true, + platform::errors::NotFound( + "The bfloat16 operator has no output slot holding the variable consumed by the following operator.")); + + op->Op()->SetOutput(op_output_name, std::vector( + {dequantize_in_node->Name()})); + + UnlinkNodes(op, op_out); // rewire: op -> dequantize_in -> dequantize -> op_out + IR_NODE_LINK_TO(op, dequantize_in_node); + IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); + IR_NODE_LINK_TO(dequantize_op, op_out); + dequantize_counter++; + } + }; + gpd(graph, handler); + PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output", + dequantize_counter, force_fp32_counter); +} + +void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const { + SetInputDataType(graph); + SetOutputDataType(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(cpu_bfloat16_pass, paddle::framework::ir::CPUBFloat16Pass); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..3a7271f7ddc59a2bdcab8457bc34d5c5c6397268 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class CPUBFloat16Pass : public Pass { + protected: + void SetInputDataType(ir::Graph* graph) const; + void SetOutputDataType(ir::Graph* graph) const; + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..15109db98321343e73fb0c3839e4f7ddf2490948 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h" +#include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/imperative/type_defs.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, bool use_mkldnn, + const std::string& mkldnn_data_type = "float32", + const bool force_fp32_output = false) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("name", name); + + if (type == "conv2d") { + op->SetInput("Input", {inputs[0]}); + op->SetOutput("Output", {outputs[0]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + op->SetAttr("force_fp32_output", force_fp32_output); + } else if (type == "pool2d" || type == "transpose2" || type == "reshape2" || + type == "dropout") { + op->SetInput("X", {inputs[0]}); + op->SetOutput("Out", {outputs[0]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + } else if (type == "fc") { + op->SetInput("Input", {inputs[0]}); + op->SetOutput("Out", {outputs[0]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + } else if (type == "concat") { + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + } else if (type == "matmul" || type == "elementwise_add") { + op->SetInput("X", {inputs[0]}); + if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); + op->SetOutput("Out", {outputs[0]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + } +} + +void PreparePass(std::unique_ptr* graph, const ProgramDesc& prog, + const std::initializer_list variable_names, + int* original_nodes_num, int* current_nodes_num) { + auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass"); + + graph->reset(pass->Apply(graph->release())); + + *original_nodes_num = 
(*graph)->Nodes().size(); + (*graph).reset(pass->Apply((*graph).release())); + *current_nodes_num = (*graph)->Nodes().size(); +} + +static const std::initializer_list variable_names{ + "z", "a", "b", "c", "d", "e", "f", "g", "h", "i"}; + +ProgramDesc BuildProgramDesc(bool use_mkldnn) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "dropout", "Dropout1", {"z"}, {"a"}, use_mkldnn, "float32"); + SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn, "bfloat16"); + SetOp(&prog, "pool2d", "Pool1", {"b"}, {"c"}, use_mkldnn, "bfloat16"); + SetOp(&prog, "conv2d", "Conv1", {"c"}, {"d"}, use_mkldnn, "bfloat16"); + SetOp(&prog, "dropout", "Dropout2", {"d"}, {"e"}, use_mkldnn, "float32"); + SetOp(&prog, "transpose2", "Transpose1", {"e"}, {"f"}, use_mkldnn, + "bfloat16"); + SetOp(&prog, "reshape2", "Reshape1", {"f"}, {"g"}, use_mkldnn, "bfloat16"); + SetOp(&prog, "concat", "Concat1", {"g"}, {"h"}, use_mkldnn, "bfloat16"); + SetOp(&prog, "dropout", "Dropout3", {"h"}, {"i"}, use_mkldnn, "float32"); + + return prog; +} + +void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, + int transpose_count, int quant_count, int dequant_count, + int added_nodes_count) { + std::unique_ptr graph(new ir::Graph(prog)); + int original_nodes_num, current_nodes_num; + PreparePass(&graph, prog, variable_names, &original_nodes_num, + ¤t_nodes_num); + + int quantize_nodes_count = 0; + int dequantize_nodes_count = 0; + int conv2d_nodes_count = 0; + int pool2d_nodes_count = 0; + int transpose2_nodes_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->Type() == "conv2d") { + conv2d_nodes_count++; + } else if (op->Type() == "pool2d") { + pool2d_nodes_count++; + } else if (op->Type() == "transpose2") { + transpose2_nodes_count++; + } else if (op->Type() == "quantize") { + quantize_nodes_count++; + } else if (op->Type() == "dequantize") { + dequantize_nodes_count++; + } + } 
+ } + EXPECT_EQ(conv2d_nodes_count, conv_count); + EXPECT_EQ(pool2d_nodes_count, pool_count); + EXPECT_EQ(transpose2_nodes_count, transpose_count); + EXPECT_EQ(quantize_nodes_count, quant_count); + EXPECT_EQ(dequantize_nodes_count, dequant_count); + EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); +} + +TEST(CpuQuantizePass, quantize) { + bool use_mkldnn = true; + // 1 quantize + 1 dequantize + int added_nodes = 2; + MainTest(BuildProgramDesc(use_mkldnn), 2, 1, 1, 1, 2, added_nodes); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(cpu_bfloat16_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d7a9c1107bbaac04a3a478014520a9b340b1d5f --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h" + +#include +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void CPUBfloat16PlacementPass::SetMkldnnDataType( + ir::Graph* graph, int* bfloat16_operators) const { + const auto& op_types_list = + Get>("bfloat16_enabled_op_types"); + // set mkldnn_data_type to bfloat16 to all operators that are in + // bfloat16_enabled_op_types vector or they are included to Bfloat16Placement + // pattern + GraphPatternDetector gpd; + patterns::Bfloat16Placement bfloat16_placement_pattern{gpd.mutable_pattern(), + "bfloat16_placement"}; + bfloat16_placement_pattern(op_types_list); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_placement_pattern); + + if ((op->Op()->HasAttr("mkldnn_data_type") || + op->Op()->HasProtoAttr("mkldnn_data_type")) && + !platform::HasOpINT8DataType(op->Op())) { + op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16")); + (*bfloat16_operators)++; // count every operator switched to bfloat16 + } + }; + gpd(graph, handler); +} + +void CPUBfloat16PlacementPass::RemoveOrhanedOperators( + ir::Graph* graph, int* bfloat16_operators) const { + // find orphaned bfloat16 operator that is between two float32 operators + // revert mkldnn_data_type attr to float32 + GraphPatternDetector gpd; + patterns::OrphanedBfloat16 orphaned_bfloat16_pattern{gpd.mutable_pattern(), + "orphaned_bfloat16"}; + orphaned_bfloat16_pattern(); + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_IR_NODE_FROM_SUBGRAPH(op, op, orphaned_bfloat16_pattern); + + op->Op()->SetAttr("mkldnn_data_type", std::string("float32")); + (*bfloat16_operators)--; // decrement the counter itself, not the pointer + }; + gpd(graph, handler); +} + +void
CPUBfloat16PlacementPass::ApplyImpl(ir::Graph* graph) const { + int bfloat16_operators = 0; + SetMkldnnDataType(graph, &bfloat16_operators); + RemoveOrhanedOperators(graph, &bfloat16_operators); + PrettyLogDetail("--- marked %d operators to bfloat16 ", + bfloat16_operators); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(cpu_bfloat16_placement_pass, + paddle::framework::ir::CPUBfloat16PlacementPass) + // a vector of operator type names with bfloat16 support ("conv2d" etc.) + // the second param is the default value for this vector + .DefaultPassAttr("bfloat16_enabled_op_types", + new std::unordered_set()); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..1911b1a3cb32a6a23585e8240c462aa84e8d869b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { +/* + * Specifies which operators should be run on bfloat16. 
+ */ +class CPUBfloat16PlacementPass : public Pass { + protected: + void SetMkldnnDataType(ir::Graph* graph, int* bfloat16_operators) const; + + void RemoveOrhanedOperators(ir::Graph* graph, int* bfloat16_operators) const; + + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9797a4bfcc0048083e059cb003746e3278a039b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc @@ -0,0 +1,132 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, + const std::string& mkldnn_data_type = "float32") { + auto* op = prog->MutableBlock(0)->AppendOp(); + + op->SetType(type); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + + if (type == "conv2d") { + op->SetAttr("name", name); + op->SetInput("Input", {inputs[0]}); + } else if (type == "relu") { + op->SetInput("X", inputs); + } else if (type == "concat") { + op->SetAttr("axis", 1); + op->SetInput("X", {inputs[0], inputs[1]}); + } else if (type == "pool2d") { + op->SetInput("X", {inputs[0]}); + } else { + FAIL() << "Unexpected operator type."; + } + op->SetOutput("Out", {outputs[0]}); +} + +// operator mkldnn_data_type +// --------------------------------------- +// (a,b)->concat->c float32 +// c->conv->f float32 +// f->relu->g float32 +// g->pool->h float32 +// h->conv->k float32 +// k->pool->l float32 +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + + for (auto& v : + std::vector({"a", "b", "c", "f", "g", "h", "k", "l"})) { + prog.MutableBlock(0)->Var(v); + } + + SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"}); + SetOp(&prog, "conv2d", "conv1", {"c"}, {"f"}); + SetOp(&prog, "relu", "relu1", {"f"}, {"g"}); + SetOp(&prog, "pool2d", "pool1", {"g"}, {"h"}); + SetOp(&prog, "conv2d", "conv2", {"h"}, {"k"}); + SetOp(&prog, "pool2d", "pool2", {"k"}, {"l"}); + + return prog; +} + +void MainTest(std::initializer_list bfloat16_enabled_op_types, + unsigned expected_bfloat16_data_type_count) { + auto prog = BuildProgramDesc(); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass"); + pass->Set("bfloat16_enabled_op_types", + new 
std::unordered_set(bfloat16_enabled_op_types)); + + graph.reset(pass->Apply(graph.release())); + + unsigned bfloat16_data_type_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + if (platform::HasOpBFLOAT16DataType(node->Op())) { + ++bfloat16_data_type_count; + } + } + } + + EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count); +} + +void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) { + auto prog = BuildProgramDesc(); + std::unique_ptr graph(new ir::Graph(prog)); + auto pass = PassRegistry::Instance().Get("cpu_bfloat16_placement_pass"); + graph.reset(pass->Apply(graph.release())); + + unsigned bfloat16_data_type_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + if (platform::HasOpBFLOAT16DataType(node->Op())) { + ++bfloat16_data_type_count; + } + } + } + EXPECT_EQ(bfloat16_data_type_count, expected_bfloat16_data_type_count); +} + +TEST(Bfloat16PlacementPass, enable_all) { + MainTest({"conv2d", "pool2d", "relu", "concat"}, 6); +} + +TEST(Bfloat16PlacementPass, enabled_conv_and_pool) { + // 2 conv2d + 2 pool2 - 1 orphaned conv2d + MainTest({"conv2d", "pool2d"}, 3); +} + +TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(0); } + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(cpu_bfloat16_placement_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 98a36a3308dc539ee5aecad9e71f50be310e584c..c19e77d2714bcfc18c2cf2a98511d31a97295daa 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -231,6 +231,10 @@ void CpuPassStrategy::EnableMkldnnQuantizer() { void CpuPassStrategy::EnableMkldnnBfloat16() { #ifdef PADDLE_WITH_MKLDNN + if (!use_mkldnn_bfloat16_) { + passes_.push_back("cpu_bfloat16_placement_pass"); + passes_.push_back("cpu_bfloat16_pass"); + } use_mkldnn_bfloat16_ = true; #else use_mkldnn_bfloat16_ 
= false; diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc index 29a86a35d7b26f41745907fb6bacf30506c027a0..a6c8f8656a4e252f1a1eedb6d67ca322f0747a66 100644 --- a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc @@ -48,6 +48,7 @@ class QuantOpKernel : public framework::OpKernel { const T* input_data = input->data(); bool is_negative = ctx.Attr("is_negative_input"); + bool bfloat16 = ctx.Attr("bfloat16"); std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data, is_negative, ctx.OutputName("Output")); @@ -74,7 +75,10 @@ class QuantOpKernel : public framework::OpKernel { src_md, engine, to_void_cast(input_data)); std::shared_ptr dst_md; - if (is_negative) { + if (bfloat16) { + platform::SetDstMemoryQuantized( + ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); + } else if (is_negative) { platform::SetDstMemoryQuantized(ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); } else { @@ -96,7 +100,11 @@ class QuantOpKernel : public framework::OpKernel { dst_memory = std::static_pointer_cast( dev_ctx.GetBlob(key_dst_mem)); auto place = ctx.GetPlace(); - if (is_negative) { + + if (bfloat16) { + dst_memory->set_data_handle( + output->mutable_data(place)); + } else if (is_negative) { dst_memory->set_data_handle(output->mutable_data(place)); } else { dst_memory->set_data_handle(output->mutable_data(place)); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index 8924e21b46f49b0fd0ec72e6acc7463d7d574d6f..602fdc6ff67787ace488379a2730dad4b8ffe1b1 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -40,6 +40,8 @@ void QuantOpMaker::Make() { AddAttr("output_format", "Convert format to NHWC or NCHW during quantization.") .SetDefault("NHWC"); + AddAttr("bfloat16", "(bool, default false) Convert to bfloat16") + 
.SetDefault(false); AddComment(R"DOC(This op will quantize data from FP32 to INT8)DOC"); } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 8fb66c6f34bd8453f1aceb731bb1cd94b8e75a69..b012a103ea3031efb381d7039b15e82b2af52bf7 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -443,6 +443,13 @@ inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) { op->GetAttrIfExists("use_quantizer")); } +inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) { + return op->GetAttrIfExists("mkldnn_data_type") == "bfloat16"; +} + +inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) { + return op->GetAttrIfExists("mkldnn_data_type") == "float32"; +} enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP }; } // namespace platform diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 9950eb9adc241ca5c82b4b0289dd57da4195e558..97056eca411f29e9a2c379cbcb2f88775242f692 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -184,6 +184,7 @@ void BindVarDsec(pybind11::module *m) { .value("FP16", pd::proto::VarType::FP16) .value("FP32", pd::proto::VarType::FP32) .value("FP64", pd::proto::VarType::FP64) + .value("BF16", pd::proto::VarType::BF16) .value("LOD_TENSOR", pd::proto::VarType::LOD_TENSOR) .value("SELECTED_ROWS", pd::proto::VarType::SELECTED_ROWS) .value("FEED_MINIBATCH", pd::proto::VarType::FEED_MINIBATCH)