From 781df300d044fc494a582bce0b68fa1c47097c41 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Tue, 23 Feb 2021 05:02:25 +0100 Subject: [PATCH] Unification of BF16 enablement process (#31034) * Unification of bfloat16 enablement process and refactor * Remove unnecessary function * Standardize the output name search --- .../framework/ir/graph_pattern_detector.cc | 91 +++++----- .../framework/ir/graph_pattern_detector.h | 40 +++-- .../framework/ir/mkldnn/cpu_bfloat16_pass.cc | 163 ++++++++---------- .../ir/mkldnn/cpu_bfloat16_pass_tester.cc | 47 +++-- .../ir/mkldnn/cpu_quantize_squash_pass.cc | 103 ++++++++++- .../ir/mkldnn/cpu_quantize_squash_pass.h | 10 ++ .../mkldnn/cpu_quantize_squash_pass_tester.cc | 67 ++++++- .../inference/api/paddle_pass_builder.cc | 1 + .../operators/mkldnn/requantize_mkldnn_op.cc | 3 +- 9 files changed, 341 insertions(+), 184 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 4de75de5ebb..a38f10ba408 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1829,9 +1829,8 @@ PDNode *patterns::OpDequant::operator()() { auto any_op = pattern->NewNode(any_op_repr()) ->assert_is_op() ->assert_more([&](Node *node) { - return (node->Op()->Type() == "matmul" || - node->Op()->Type() == "conv2d" || - node->Op()->Type() == "fc"); + return (node->Op()->HasAttr("force_fp32_output") || + node->Op()->HasProtoAttr("force_fp32_output")); }); auto dequant_in = pattern->NewNode(dequant_in_repr()) ->assert_is_op_input("dequantize", "Input"); @@ -1865,6 +1864,44 @@ PDNode *patterns::DequantScale::operator()() { return scale_out; } +PDNode *patterns::ScaleQuant::operator()() { + auto scale_in = pattern->NewNode(scale_in_repr()) + ->AsInput() + ->assert_is_op_input("scale", "X"); + auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale"); + + auto quant_in = pattern->NewNode(quant_in_repr()) + ->AsInput() + ->assert_is_op_input("quantize", "Input"); + auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize"); + + scale_op->LinksFrom({scale_in}).LinksTo({quant_in}); + quant_op->LinksFrom({quant_in}); + + return quant_op; +} + +PDNode *patterns::QuantConv::operator()() { + auto quant_in = pattern->NewNode(quant_in_repr()) + ->AsInput() + ->assert_is_op_input("quantize", "Input"); + auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize"); + + auto conv_in = pattern->NewNode(conv_in_repr()) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); + conv_op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("mkldnn_data_type") == + "bfloat16"; + }); + + quant_op->LinksFrom({quant_in}).LinksTo({conv_in}); + conv_op->LinksFrom({conv_in}); + + return quant_op; +} + PDNode *patterns::ScaleMatmul::operator()() { auto scale_in = pattern->NewNode(scale_in_repr()) ->AsInput() @@ -2191,10 +2228,11 @@ PDNode *patterns::QuantizePlacement::operator()( PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set( - {"concat", "conv2d", "conv2d_transpose", "elementwise_add", - "elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm", - "matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"}); + std::unordered_set({"concat", "conv2d", "conv2d_transpose", + "elementwise_add", 
"elementwise_mul", + "fc", "fusion_gru", "gelu", "layer_norm", + "matmul", "pool2d", "relu", "reshape2", + "softmax", "sum", "transpose2"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } @@ -2240,25 +2278,12 @@ PDNode *patterns::LastBfloat16Ops::operator()() { "bfloat16"; }); auto *op_out = pattern->NewNode(op_out_repr())->AsOutput(); - - auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op(); - next_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") != - "bfloat16"; - }); - op->LinksTo({op_out}); - next_op->LinksFrom({op_out}); - return next_op; + return op_out; } PDNode *patterns::FirstBfloat16Ops::operator()() { - auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); - prev_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") != - "bfloat16"; - }); - auto *op_in = pattern->NewNode(op_in_repr())->AsOutput(); + auto *op_in = pattern->NewNode(op_in_repr())->AsInput(); auto *op = pattern->NewNode(op_repr())->assert_is_op(); op->assert_more([&](Node *node) { @@ -2266,7 +2291,6 @@ PDNode *patterns::FirstBfloat16Ops::operator()() { "bfloat16"; }); - prev_op->LinksTo({op_in}); op->LinksFrom({op_in}); return op; } @@ -2280,27 +2304,6 @@ PDNode *patterns::DuplicatedInputs::operator()() { return op; } -PDNode *patterns::UnnecessaryReorders::operator()() { - auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op(); - prev_op->assert_more([&](Node *node) { - return node->Op()->GetAttrIfExists("mkldnn_data_type") == - "bfloat16"; - }); - - auto *quant_in = pattern->NewNode(quant_in_repr()) - ->assert_is_op_input("quantize", "Input"); - - auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize"); - - auto *quant_out = pattern->NewNode(quant_out_repr()) - ->assert_is_op_output("quantize", "Output"); - - prev_op->LinksTo({quant_in}); - quant_op->LinksFrom({quant_in}).LinksTo({quant_out}); - - return quant_out; -} - PDNode *patterns::MKLDNNInPlace::operator()() { const std::unordered_set &supported_op_types = { "abs", diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f9b6e0ef9c9..2e518c1d4df 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1135,11 +1135,36 @@ struct DequantScale : public PatternBase { PATTERN_DECL_NODE(dequant_op); PATTERN_DECL_NODE(dequant_out); - PATTERN_DECL_NODE(scale_op); PATTERN_DECL_NODE(scale_out); }; +// Scale + Quantize +struct ScaleQuant : public PatternBase { + ScaleQuant(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "scale_quant") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(scale_in); + PATTERN_DECL_NODE(scale_op); + PATTERN_DECL_NODE(quant_in); + PATTERN_DECL_NODE(quant_op); +}; + +// Quantize + Conv2d +struct QuantConv : public PatternBase { + QuantConv(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "quant_conv") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(quant_in); + PATTERN_DECL_NODE(quant_op); + PATTERN_DECL_NODE(conv_in); + PATTERN_DECL_NODE(conv_op); +}; + // Scale + Matmul struct ScaleMatmul : public PatternBase { ScaleMatmul(PDPattern* pattern, const std::string& name_scope) @@ -1338,7 +1363,6 @@ struct LastBfloat16Ops : public PatternBase { PATTERN_DECL_NODE(op); PATTERN_DECL_NODE(op_out); - PATTERN_DECL_NODE(next_op); }; struct FirstBfloat16Ops : public 
PatternBase { @@ -1346,7 +1370,6 @@ struct FirstBfloat16Ops : public PatternBase { : PatternBase(pattern, name_scope, "first_bfloat16_ops") {} PDNode* operator()(); - PATTERN_DECL_NODE(prev_op); PATTERN_DECL_NODE(op_in); PATTERN_DECL_NODE(op); }; @@ -1360,17 +1383,6 @@ struct DuplicatedInputs : public PatternBase { PATTERN_DECL_NODE(op); }; -struct UnnecessaryReorders : public PatternBase { - UnnecessaryReorders(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "unnecessary_reorders") {} - PDNode* operator()(); - - PATTERN_DECL_NODE(prev_op); - PATTERN_DECL_NODE(quant_in); - PATTERN_DECL_NODE(quant_op); - PATTERN_DECL_NODE(quant_out); -}; - // Pattern used for enforcing inplace computation for in-place computation // supporting DNNL ops. softmax, batch_norm and layer_norm struct MKLDNNInPlace : public PatternBase { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc index 9658d604520..5f9aefc1e7a 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc @@ -12,12 +12,10 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h" #include -#include #include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { @@ -33,8 +31,38 @@ void UnlinkNodes(ir::Node* a, ir::Node* b) { b->inputs.end()); } +// Checking whether a reorder from FP32 to BF16 should be added before the input +// to the operator +bool IsPermittedInputName(const std::string& input_name) { + // Only the inputs listed in \"permitted_names\" requires quanitization before + // the bfloat16 operator. Other inputs, such as Filter and Bias are reordered + // in the kernel. + const std::vector permitted_names = {"X", "Y", "Input", + "ResidualData"}; + return (std::find(permitted_names.begin(), permitted_names.end(), + input_name) != permitted_names.end()); +} + +// Checking whether a reorder from BF16 to FP32 should be added after the output +// to the operator +bool IsPermittedOutputName(const std::string& output_name) { + // XShape is output in transpose2 and reshape2 operators used to store the + // shape and lod of X. So this output do not need dequantize before. + return (output_name != "XShape"); +} + void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, int* quantize_counter) { + std::vector input_names; + + // Find the name of the input linking op to op_in + for (auto name : op->Op()->InputNames()) + for (auto input_name : op->Op()->Input(name)) + if (input_name == op_in->Name() && IsPermittedInputName(name)) + input_names.push_back(name); + + if (input_names.empty()) return; + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc); @@ -44,23 +72,12 @@ void AddQuantize(Graph* g, ir::Node* op, ir::Node* op_in, q_desc.SetOutput("Output", std::vector({quantize_out_node->Name()})); q_desc.SetAttr("Scale", 1.f); + q_desc.SetAttr("Shift", 0.0f); q_desc.SetAttr("bfloat16", true); q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout") ? 
op->Op()->GetAttr("data_layout") : std::string("NCHW")); - auto quantize_op = g->CreateOpNode(&q_desc); - - std::vector input_names; - for (auto name : op->Op()->InputNames()) { - for (auto input_name : op->Op()->Input(name)) { - if (input_name == op_in->Name()) input_names.push_back(name); - } - } - - PADDLE_ENFORCE_NE( - input_names.empty(), true, - platform::errors::NotFound( - "Operator before operator should have input as op output")); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. for (auto name = input_names.begin(); name < input_names.end(); name++) op->Op()->SetInput(*name, @@ -99,11 +116,12 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { q_desc.SetOutput("Output", std::vector({quantize_out_node_names[i]})); q_desc.SetAttr("Scale", 1.f); + q_desc.SetAttr("Shift", 0.0f); q_desc.SetAttr("bfloat16", true); q_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout") ? op->Op()->GetAttr("data_layout") : std::string("NCHW")); - auto quantize_op = g->CreateOpNode(&q_desc); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. UnlinkNodes(inputs[i], op); IR_NODE_LINK_TO(inputs[i], quantize_op); @@ -115,6 +133,9 @@ void AddQuantizes(Graph* g, ir::Node* op, int* quantize_counter) { op->Op()->SetInput("X", quantize_out_node_names); } +// Operators like Concat and Sum have a single input name X, which actually +// consists of multiple inputs. Such operators require a different way to find +// pattern and add quantize ops. void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) { GraphPatternDetector gpd; patterns::DuplicatedInputs duplicated_inputs{gpd.mutable_pattern(), @@ -128,38 +149,8 @@ void AddReoderBeforeDuplicatedInputs(ir::Graph* graph, int* quantize_counter) { gpd(graph, handler); } -void RemoveUnnecessaryReorders(ir::Graph* graph, int* quantize_counter) { - GraphPatternDetector gpd; - patterns::UnnecessaryReorders unnecessary_reorders{gpd.mutable_pattern(), - "unnecessary_reorders"}; - unnecessary_reorders(); - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, unnecessary_reorders); - GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, unnecessary_reorders); - GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, unnecessary_reorders); - GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, unnecessary_reorders); - - std::string op_output_name; - for (auto name : prev_op->Op()->OutputNames()) - for (auto output_name : prev_op->Op()->Output(name)) - if (output_name == quant_in->Name()) op_output_name = name; - - PADDLE_ENFORCE_NE( - op_output_name.empty(), true, - platform::errors::NotFound( - "Operator before operator should have input as op output")); - - prev_op->Op()->SetOutput(op_output_name, - std::vector({quant_out->Name()})); - - IR_NODE_LINK_TO(prev_op, quant_out); - GraphSafeRemoveNodes(graph, {quant_in, quant_op}); - (*quantize_counter)--; - }; - gpd(graph, handler); -} - +// Adding quantize ops before all operators except Concat and Sum, which have +// already been handled in AddReoderBeforeDuplicatedInputs void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { GraphPatternDetector gpd; patterns::FirstBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), @@ -167,12 +158,9 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { bfloat16_ops(); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, bfloat16_ops); 
GET_IR_NODE_FROM_SUBGRAPH(op_in, op_in, bfloat16_ops); GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); - auto prev_op_type = prev_op->Op()->Type(); - if (op->Op()->Type() != "conv2d" && prev_op_type != "quantize" && - prev_op_type != "sum" && prev_op_type != "concat") { + if (op->Op()->Type() != "sum" && op->Op()->Type() != "concat") { AddQuantize(g, op, op_in, quantize_counter); } }; @@ -182,9 +170,8 @@ void AddReoderBeforeSingleInputs(ir::Graph* graph, int* quantize_counter) { void CPUBFloat16Pass::SetInputDataType(ir::Graph* graph) const { int quantize_counter = 0; AddReoderBeforeDuplicatedInputs(graph, &quantize_counter); - RemoveUnnecessaryReorders(graph, &quantize_counter); AddReoderBeforeSingleInputs(graph, &quantize_counter); - PrettyLogDetail("--- added %d quantize op before bfloat16 op", + PrettyLogDetail("--- added %d quantize ops before bfloat16 op", quantize_counter); } @@ -193,55 +180,51 @@ void CPUBFloat16Pass::SetOutputDataType(ir::Graph* graph) const { patterns::LastBfloat16Ops bfloat16_ops{gpd.mutable_pattern(), "last_bfloat16_ops"}; bfloat16_ops(); - int force_fp32_counter = 0, dequantize_counter = 0; + int dequantize_counter = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_IR_NODE_FROM_SUBGRAPH(op, op, bfloat16_ops); GET_IR_NODE_FROM_SUBGRAPH(op_out, op_out, bfloat16_ops); - GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, bfloat16_ops); - if ((op->Op()->HasAttr("force_fp32_output") || - op->Op()->HasProtoAttr("force_fp32_output")) && - !op->Op()->GetAttrIfExists("fuse_residual_connection")) { - op->Op()->SetAttr("force_fp32_output", true); - force_fp32_counter++; - } else if (op->Op()->Type() != "prior_box") { - VarDesc dequantize_out_desc(patterns::PDNodeName("dequantize", "out")); - auto* dequantize_out_node = g->CreateVarNode(&dequantize_out_desc); + + if (op->Op()->Type() != "prior_box") { + // Find the name of the output linking op to op_out + std::vector output_names; + for (auto name : op->Op()->OutputNames()) + for (auto output_name : op->Op()->Output(name)) + if (output_name == op_out->Name() && IsPermittedOutputName(name)) + output_names.push_back(name); + + if (output_names.empty()) return; + + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); OpDesc deq_desc; deq_desc.SetType("dequantize"); - deq_desc.SetInput("Input", std::vector({op_out->Name()})); - deq_desc.SetOutput( - "Output", std::vector({dequantize_out_node->Name()})); + deq_desc.SetInput("Input", + std::vector({dequantize_in_node->Name()})); + deq_desc.SetOutput("Output", std::vector({op_out->Name()})); deq_desc.SetAttr("Scale", 1.0f); - auto dequantize_op = g->CreateOpNode(&deq_desc); - - std::string next_op_input_name; - for (auto name : next_op->Op()->InputNames()) { - for (auto input_name : next_op->Op()->Input(name)) { - if (input_name == op_out->Name()) next_op_input_name = name; - } - } - - PADDLE_ENFORCE_NE( - next_op_input_name.empty(), true, - platform::errors::NotFound( - "Operator before operator should have input as op output")); - - next_op->Op()->SetInput( - next_op_input_name, - std::vector({dequantize_out_node->Name()})); - UnlinkNodes(op_out, next_op); - IR_NODE_LINK_TO(op_out, dequantize_op); - IR_NODE_LINK_TO(dequantize_op, dequantize_out_node); - IR_NODE_LINK_TO(dequantize_out_node, next_op); + deq_desc.SetAttr("Shift", 0.0f); + auto dequantize_op = + g->CreateOpNode(&deq_desc); // OpDesc will be copied. 
+ + for (auto name = output_names.begin(); name < output_names.end(); name++) + op->Op()->SetOutput( + *name, std::vector({dequantize_in_node->Name()})); + + UnlinkNodes(op, op_out); + IR_NODE_LINK_TO(op, dequantize_in_node); + IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); + IR_NODE_LINK_TO(dequantize_op, op_out); + dequantize_counter++; } }; gpd(graph, handler); - PrettyLogDetail("--- added %d dequantize op and used %d force_fp32_output", - dequantize_counter, force_fp32_counter); + PrettyLogDetail("--- added %d dequantize ops after bfloat16 op", + dequantize_counter); } void CPUBFloat16Pass::ApplyImpl(ir::Graph* graph) const { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc index ab8d3cbdfc0..f620b4c94fe 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc @@ -26,8 +26,7 @@ namespace ir { void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, const std::vector& inputs, const std::vector& outputs, bool use_mkldnn, - const std::string& mkldnn_data_type = "float32", - const bool force_fp32_output = false) { + const std::string& mkldnn_data_type = "float32") { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); op->SetAttr("use_mkldnn", use_mkldnn); @@ -37,7 +36,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); op->SetAttr("mkldnn_data_type", mkldnn_data_type); - op->SetAttr("force_fp32_output", force_fp32_output); } else if (type == "pool2d" || type == "transpose2" || type == "reshape2" || type == "dropout") { op->SetInput("X", {inputs[0]}); @@ -47,7 +45,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("Input", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); op->SetAttr("mkldnn_data_type", mkldnn_data_type); - op->SetAttr("force_fp32_output", force_fp32_output); } else if (type == "concat" || type == "sum") { op->SetInput("X", inputs); op->SetOutput("Out", outputs); @@ -58,7 +55,6 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); op->SetOutput("Out", {outputs[0]}); op->SetAttr("mkldnn_data_type", mkldnn_data_type); - if (type == "matmul") op->SetAttr("force_fp32_output", force_fp32_output); } else if (type == "layer_norm") { op->SetInput("X", {inputs[0]}); op->SetOutput("Y", {outputs[0]}); @@ -79,8 +75,8 @@ void PreparePass(std::unique_ptr* graph, const ProgramDesc& prog, *current_nodes_num = (*graph)->Nodes().size(); } -void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count, - int force_fp32_count, int added_nodes_count) { +void MainTest(const ProgramDesc& prog, const int& quant_count, + const int& dequant_count, const int& added_nodes_count) { std::unique_ptr graph(new ir::Graph(prog)); int original_nodes_num, current_nodes_num; PreparePass(&graph, prog, variable_names, &original_nodes_num, @@ -88,7 +84,6 @@ void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count, int quantize_nodes_count = 0; int dequantize_nodes_count = 0; - int force_fp32_nodes_count = 0; for (auto* node : graph->Nodes()) { if (node->IsOp()) { auto* op = node->Op(); @@ -96,16 +91,11 @@ void MainTest(const ProgramDesc& prog, int quant_count, int dequant_count, quantize_nodes_count++; } else if (op->Type() == "dequantize") 
{ dequantize_nodes_count++; - } else if (op->Type() == "conv2d" || op->Type() == "matmul" || - op->Type() == "fc") { - if (op->GetAttrIfExists("force_fp32_output")) - force_fp32_nodes_count++; } } } EXPECT_EQ(quantize_nodes_count, quant_count); EXPECT_EQ(dequantize_nodes_count, dequant_count); - EXPECT_EQ(force_fp32_nodes_count, force_fp32_count); EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); } @@ -125,9 +115,10 @@ ProgramDesc BuildProgramDescConv(bool use_mkldnn) { TEST(CpuBfloat16Pass, convolution) { bool use_mkldnn = true; - // 0 added + 1 force_fp32_output - int added_nodes = 0; - MainTest(BuildProgramDescConv(use_mkldnn), 0, 0, 1, added_nodes); + int quant_op = 3; + int dequant_op = 3; + int added_nodes = quant_op * 2 + dequant_op * 2; + MainTest(BuildProgramDescConv(use_mkldnn), quant_op, dequant_op, added_nodes); } ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) { @@ -147,9 +138,11 @@ ProgramDesc BuildProgramDescDoubleInput(bool use_mkldnn) { TEST(CpuBfloat16Pass, double_input_ops) { bool use_mkldnn = true; - // 2 quant + 2 quant out - int added_nodes = 4; - MainTest(BuildProgramDescDoubleInput(use_mkldnn), 2, 0, 0, added_nodes); + int quant_op = 4; + int dequant_op = 3; + int added_nodes = quant_op * 2 + dequant_op * 2; + MainTest(BuildProgramDescDoubleInput(use_mkldnn), quant_op, dequant_op, + added_nodes); } ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) { @@ -169,9 +162,11 @@ ProgramDesc BuildProgramDescDuplicatedInput(bool use_mkldnn) { TEST(CpuBfloat16Pass, duplicated_input_ops) { bool use_mkldnn = true; - // 3 quant + 3 quant out - int added_nodes = 6; - MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), 3, 0, 0, added_nodes); + int quant_op = 5; + int dequant_op = 3; + int added_nodes = quant_op * 2 + dequant_op * 2; + MainTest(BuildProgramDescDuplicatedInput(use_mkldnn), quant_op, dequant_op, + added_nodes); } ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) { @@ -193,9 +188,11 @@ ProgramDesc BuildProgramDescDoubleOutputs(bool use_mkldnn) { TEST(CpuBfloat16Pass, double_outputs_ops) { bool use_mkldnn = true; - // 3 dequant + 3 dequant out - int added_nodes = 6; - MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), 0, 3, 0, added_nodes); + int quant_op = 3; + int dequant_op = 3; + int added_nodes = quant_op * 2 + dequant_op * 2; + MainTest(BuildProgramDescDoubleOutputs(use_mkldnn), quant_op, dequant_op, + added_nodes); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index d6146f264ab..34668192f0b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -255,14 +255,21 @@ void CPUQuantizeSquashPass::OpDequantSquash(Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, op_dequant_pattern); if (dequant_in->outputs.size() == 1) { - auto output_name = "Out"; - if (any_op->Op()->Type() == "conv2d") { + if (any_op->Op()->Type() == "conv2d" || + any_op->Op()->Type() == "conv2d_transpose") { // do not squash if fuse residual connection is true // because residual fusion does not support force output with fp32 if (any_op->Op()->GetAttrIfExists("fuse_residual_connection")) return; - output_name = "Output"; } + // Find the name of the output linking any_op to dequant_in + std::string output_name; + for (auto name : any_op->Op()->OutputNames()) + for (auto out_name : any_op->Op()->Output(name)) + if (out_name == 
dequant_in->Name()) output_name = name; + + if (output_name.empty()) return; + any_op->Op()->SetAttr("force_fp32_output", true); any_op->Op()->SetOutput(output_name, std::vector({dequant_out->Name()})); @@ -363,10 +370,10 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { platform::errors::InvalidArgument( "Dequantize scale(%f) should have positive value.", dequant_scale)); - PADDLE_ENFORCE_GT(scale_scale, 0.0f, - platform::errors::InvalidArgument( - "Scale(%f) of scale op should have positive value.", - scale_scale)); + PADDLE_ENFORCE_NE( + scale_scale, 0.0f, + platform::errors::InvalidArgument( + "Scale(%f) should have a non-zero value", scale_scale)); dequant_op->Op()->SetAttr("Scale", dequant_scale / scale_scale); dequant_op->Op()->SetOutput( @@ -378,10 +385,86 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const { }; gpd(graph, handler); AddStatis(found_dequant_scale_squash_count); - PrettyLogDetail("--- squashed %d scale with dequant", + PrettyLogDetail("--- squashed %d scale with dequantize op", found_dequant_scale_squash_count); } +// squash scale with quantize +void CPUQuantizeSquashPass::ScaleQuantSquash(Graph* graph) const { + GraphPatternDetector gpd; + patterns::ScaleQuant scale_quant_pattern{gpd.mutable_pattern(), + "scale_quant"}; + scale_quant_pattern(); + + int found_scale_quant_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash scale-quant ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_quant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, scale_quant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, scale_quant_pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, scale_quant_pattern); + + if (quant_in->outputs.size() == 1 && + scale_op->Op()->GetAttrIfExists("bias") == 0.0) { + auto quant_scale = quant_op->Op()->GetAttrIfExists("Scale"); + auto scale_scale = scale_op->Op()->GetAttrIfExists("scale"); + + PADDLE_ENFORCE_GT( + quant_scale, 0.0f, + platform::errors::InvalidArgument( + "Quantize scale(%f) should have positive value.", quant_scale)); + PADDLE_ENFORCE_NE( + scale_scale, 0.0f, + platform::errors::InvalidArgument( + "Scale(%f) should have a non-zero value", scale_scale)); + + quant_op->Op()->SetAttr("Scale", quant_scale * scale_scale); + quant_op->Op()->SetInput("Input", + std::vector({scale_in->Name()})); + IR_NODE_LINK_TO(scale_in, quant_op); + GraphSafeRemoveNodes(graph, {scale_op, quant_in}); + found_scale_quant_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_scale_quant_squash_count); + PrettyLogDetail("--- squashed %d scale with quantize op", + found_scale_quant_squash_count); +} + +// squash quantize if is before bfloat16 conv2d +void CPUQuantizeSquashPass::QuantizeBf16Conv(Graph* graph) const { + GraphPatternDetector gpd; + patterns::QuantConv pattern{gpd.mutable_pattern(), "quant_conv"}; + pattern(); + + int found_quant_conv_squash_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "squash quant-conv2d ops pair"; + + GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_in, conv_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, pattern); + + if (conv_in->outputs.size() == 1 && + quant_op->Op()->GetAttrIfExists("Scale") == 1.0) { + conv_op->Op()->SetInput("Input", + std::vector({quant_in->Name()})); + IR_NODE_LINK_TO(quant_in, 
conv_op); + GraphSafeRemoveNodes(graph, {quant_op, conv_in}); + found_quant_conv_squash_count++; + } + }; + gpd(graph, handler); + AddStatis(found_quant_conv_squash_count); + PrettyLogDetail("--- squashed %d quantize with bfloat16 conv2d op", + found_quant_conv_squash_count); +} + void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, @@ -389,6 +472,8 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { "The graph in function CPUQuantizeSquashPass::ApplyImpl is null.")); FusePassBase::Init("cpu_quantize_squash_pass", graph); + DequantScaleSquash(graph); + ScaleQuantSquash(graph); std::unordered_map nodes_keep_counter; FindNodesToKeep(graph, &nodes_keep_counter); DequantQuantSquash(graph, &nodes_keep_counter); @@ -396,7 +481,7 @@ void CPUQuantizeSquashPass::ApplyImpl(ir::Graph* graph) const { RequantOpSquash(graph); OpDequantSquash(graph); MultipleQuantizeSquash(graph); - DequantScaleSquash(graph); + QuantizeBf16Conv(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h index d1465f9da5c..b34d5062e3e 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h @@ -78,6 +78,16 @@ class CPUQuantizeSquashPass : public FusePassBase { */ void DequantScaleSquash(Graph* graph) const; + /* + * Squash scale if scale is before quantize + */ + void ScaleQuantSquash(Graph* graph) const; + + /* + * Squash quantize if is before bfloat16 conv2d + */ + void QuantizeBf16Conv(Graph* graph) const; + const std::string name_scope_{"squash"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index 37af0274ea8..08e2041a9a1 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -24,7 +24,8 @@ namespace ir { void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, const std::vector& inputs, const std::vector& outputs, bool use_mkldnn, - const std::vector scale = {}, float bias = 0.0) { + const std::vector scale = {}, float bias = 0.0, + const std::string& mkldnn_data_type = "float32") { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); op->SetAttr("use_mkldnn", use_mkldnn); @@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]}); if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); op->SetOutput("Output", {outputs[0]}); + op->SetAttr("force_fp32_output", false); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); } else if (type == "quantize") { op->SetInput("Input", {inputs[0]}); op->SetOutput("Output", {outputs[0]}); @@ -52,6 +55,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, } else if (type == "concat") { op->SetInput("X", inputs); op->SetOutput("Out", outputs); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); } else if (type == "fc") { op->SetInput("Input", {inputs[0]}); PADDLE_ENFORCE_EQ(inputs.size(), 2UL, @@ -63,6 +67,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetOutput("Out", outputs); if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]); if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]); + op->SetAttr("force_fp32_output", 
false); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); } else if (type == "scale") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); @@ -74,6 +80,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetOutput("Out", {outputs[0]}); if (scale.size() > 0) op->SetAttr("Scale_x", scale[0]); if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]); + op->SetAttr("force_fp32_output", false); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); } } @@ -299,6 +307,20 @@ ProgramDesc BuildDequantScaleProgramDesc(bool use_mkldnn, float dequant_scale, return prog; } +// a->Scale->b +// b->Quant->c +ProgramDesc BuildScaleQuantProgramDesc(bool use_mkldnn, float scale_scale, + float quant_scale, float bias) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "scale", "Scale", {"a"}, {"b"}, use_mkldnn, {scale_scale}, bias); + SetOp(&prog, "quantize", "Quant", {"b"}, {"c"}, use_mkldnn, {quant_scale}); + + return prog; +} + // {x,y}->Matmul->b // b->Dequant->c ProgramDesc BuildMatmulDequantProgramDesc(bool use_mkldnn, @@ -341,6 +363,22 @@ ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, float requant_scale_in, return prog; } +// a->Quant->b +// b->Conv2d->c +ProgramDesc BuildQuantConv2dProgramDesc(const bool& use_mkldnn, + const float& quant_scale, + const std::string& mkldnn_data_type) { + ProgramDesc prog; + for (auto& v : variable_names) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "quantize", "Quant", {"a"}, {"b"}, use_mkldnn, {quant_scale}); + SetOp(&prog, "conv2d", "Conv2d", {"b"}, {"c"}, use_mkldnn, {}, 0.0f, + mkldnn_data_type); + + return prog; +} + void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, const char* var_name) { auto x = scope->Var(var_name); @@ -664,6 +702,22 @@ TEST(CpuQuantizeSquashPass, dequantize_scale_with_bias) { "Dequant", "Scale", dequant_scale); } +// if scale has no bias +TEST(CpuQuantizeSquashPass, scale_with_no_bias_quantize) { + constexpr auto scale_scale = 1.5432f; + constexpr auto quant_scale = 1.2345f; + constexpr auto bias = 0.0f; + auto use_mkldnn = true; + // remove: dequant out, scale op + auto remove_nodes = 2; + CountNodeTest( + BuildScaleQuantProgramDesc(use_mkldnn, scale_scale, quant_scale, bias), + remove_nodes); + EqualScaleTest( + BuildScaleQuantProgramDesc(use_mkldnn, scale_scale, quant_scale, bias), + "Scale", "Quant", quant_scale * scale_scale); +} + TEST(CpuQuantizeSquashPass, matmul_with_dequant) { auto dequant_scale = 1.2345f; auto use_mkldnn = true; @@ -688,6 +742,17 @@ TEST(CpuQuantizeSquashPass, requantize_with_matmul_fc_conv) { EqualScaleTest(program_desc, "Conv", "Scale_in", requant_scale_in); } +TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) { + auto quant_scale = 1.0f; + auto use_mkldnn = true; + auto mkldnn_data_type = "bfloat16"; + // remove: quant_op, conv_in + auto remove_nodes = 2; + CountNodeTest( + BuildQuantConv2dProgramDesc(use_mkldnn, quant_scale, mkldnn_data_type), + remove_nodes); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index b7291ef3077..2940bc01d73 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -268,6 +268,7 @@ void CpuPassStrategy::EnableMkldnnBfloat16() { if (!use_mkldnn_bfloat16_) { passes_.push_back("cpu_bfloat16_placement_pass"); 
passes_.push_back("cpu_bfloat16_pass"); + passes_.push_back("cpu_quantize_squash_pass"); } use_mkldnn_bfloat16_ = true; #else diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc index 33422455ada..4c136a2fc2c 100644 --- a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -156,4 +156,5 @@ class ReQuantOpKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, - ops::ReQuantOpKernel, ops::ReQuantOpKernel); + ops::ReQuantOpKernel, ops::ReQuantOpKernel, + ops::ReQuantOpKernel); -- GitLab