diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
index b38f30f39df608f99dc4c6e2642d971e48e3f22b..7bcb127450a351000cbb32c53d47d9abb9ab5afd 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -192,7 +192,6 @@ void MainTest(const ProgramDesc& prog,
   int original_nodes_num, current_nodes_num;
   PreparePass(&graph, prog, variable_names, &original_nodes_num,
               &current_nodes_num, var_without_scale, var_signed);
-  std::unordered_map<std::string, int> actual_operators;
   for (auto* node : graph->Nodes()) {
     if (node->IsOp()) {
       auto* op = node->Op();
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
index 2b9419a5502f1cbc79c346d2587feda52508439a..64f9dfdc0801a06d9f7492390f771b5749a96572 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc
@@ -104,6 +104,34 @@ void CPUQuantizeSquashPass::FindNodesToKeep(
   AddStatis(found_count);
 }
 
+bool CPUQuantizeSquashPass::IsDequantizeInputUint8(
+    const Node* dequant_in) const {
+  PADDLE_ENFORCE_EQ(
+      dequant_in->inputs.size(), 1,
+      platform::errors::InvalidArgument(
+          "Dequantize (id: %d) should have only one input.", dequant_in->id()));
+  if (dequant_in->inputs[0]->IsOp()) {
+    auto prev_op = dequant_in->inputs[0]->Op();
+    std::string act_name;
+    if (prev_op->Type() == "relu") {
+      return true;
+    } else {
+      if (prev_op->Type() == "conv2d") {
+        act_name = "fuse_activation";
+      } else if (prev_op->Type() == "fc") {
+        act_name = "activation_type";
+      }
+      if (!act_name.empty()) {
+        auto act = prev_op->GetAttrIfExists<std::string>(act_name);
+        if (act == "relu" || act == "relu6") {
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 void CPUQuantizeSquashPass::DequantQuantSquash(
     Graph* graph,
     std::unordered_map<std::string, int>* nodes_keep_counter) const {
@@ -123,6 +151,12 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
     GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
 
+    // Don't squash if e.g. just one concat input is unsigned
+    if (IsDequantizeInputUint8(dequant_in) &&
+        !quant_op->Op()->GetAttrIfExists<bool>("is_negative_input")) {
+      return;
+    }
+
     auto* next_op_desc = next_op->Op();
     float dequant_scale =
         BOOST_GET_CONST(float, dequant_op->Op()->GetAttr("Scale"));
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
index abd0f741b76317fba96748a2ed0b2182b59696bb..d668c222a4ecdba481b7c07905a43fa5cda479b5 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h
@@ -43,6 +43,11 @@ class CPUQuantizeSquashPass : public FusePassBase {
       Graph* graph,
       std::unordered_map<std::string, int>* nodes_keep_counter) const;
 
+  /*
+   * Check if input to dequantize is uint8
+   */
+  bool IsDequantizeInputUint8(const Node* dequant_in) const;
+
   /*
    * Squash dequantize-quantize ops pairs into requantize or nothing
    */
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
index f1352ebaad6d8df6e0d535a364f83e3b55cb9f93..f89893c050b8cefc6644a170532adb1d631e6a3b 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc
@@ -26,12 +26,26 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& outputs, bool use_mkldnn,
            const std::vector<float> scale = {}, float bias = 0.0,
            const std::string& mkldnn_data_type = "float32",
-           bool bias_after_scale = false, int groups = 1) {
+           bool bias_after_scale = false, int groups = 1,
+           bool is_negative_input = true) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   op->SetAttr("use_mkldnn", use_mkldnn);
   op->SetAttr("name", name);
-  if (type == "conv2d") {
+  if (type != "dropout" && type != "quantize" && type != "dequantize") {
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
+  }
+  if (type == "pool2d") {
+    op->SetInput("X", {inputs[0]});
+    op->SetOutput("Out", {outputs[0]});
+    if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
+    if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
+  } else if (type == "relu") {
+    op->SetInput("X", {inputs[0]});
+    op->SetOutput("Out", {outputs[0]});
+    if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
+    if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
+  } else if (type == "conv2d") {
     if (scale.size() > 0) op->SetAttr("Scale_in", scale[0]);
     if (scale.size() > 1) op->SetAttr("Scale_out", scale[1]);
     op->SetInput("Input", {inputs[0]});
@@ -48,11 +62,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
     op->SetAttr("data_format", std::string("NCHW"));
     op->SetAttr("force_fp32_output", false);
-    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   } else if (type == "quantize") {
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Output", {outputs[0]});
     op->SetAttr("Scale", scale[0]);
+    op->SetAttr("is_negative_input", is_negative_input);
   } else if (type == "dequantize") {
     op->SetInput("Input", {inputs[0]});
     op->SetOutput("Output", {outputs[0]});
@@ -121,7 +135,8 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out,
 }
 
 static const std::initializer_list<std::string> variable_names{
-    "a", "b", "c", "d", "e", "f", "g", "h", "i", "x", "y", "w1", "w2"};
+    "a", "b", "c", "d", "e", "f", "g", "h",
+    "i", "j", "k", "l", "x", "y", "w1", "w2"};
 
 // a->Conv1(scale1)->b
 // b->Dequant(scale1)->c
@@ -219,6 +234,35 @@ ProgramDesc BuildConvMultiRequantProgramDesc(bool use_mkldnn, float scale_out,
   return prog;
 }
 
+/* a->pool2d->b->Dequant->c(s8)->Quant->d-\
+ * e->relu->f->Dequant->g(u8)->Quant->h--Concat1->x
+ * i->pool2d->j->Dequant->k(s8)->Quant->l-/
+ */
+ProgramDesc BuildConvS8U8S8ConcatProgramDesc(float scale_out, float scale) {
+  ProgramDesc prog;
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "pool2d", "Pool2d1", {"a"}, {"b"}, true, {scale, scale_out});
+  SetOp(&prog, "relu", "Relu1", {"e"}, {"f"}, true, {scale, scale_out});
+  SetOp(&prog, "pool2d", "Pool2d2", {"i"}, {"j"}, true, {scale, scale_out});
+
+  SetOp(&prog, "dequantize", "Dequant1", {"b"}, {"c"}, true,
+        {scale, scale_out});
+  SetOp(&prog, "dequantize", "Dequant2", {"f"}, {"g"}, true,
+        {scale, scale_out});
+  SetOp(&prog, "dequantize", "Dequant3", {"j"}, {"k"}, true,
+        {scale, scale_out});
+
+  SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, true, {scale, scale_out});
+  SetOp(&prog, "quantize", "Quant2", {"g"}, {"h"}, true, {scale, scale_out},
+        0.0, "float32", false, 1, false);
+  SetOp(&prog, "quantize", "Quant3", {"k"}, {"l"}, true, {scale, scale_out});
+
+  SetOp(&prog, "concat", "Concat1", {"d", "h", "l"}, {"x"}, true);
+  return prog;
+}
+
 // a->Conv1->b
 // b->Dequant1(Scale1)->c
 // c->Concat
@@ -426,6 +470,31 @@ void CountNodeTest(const ProgramDesc& prog, int removed_nodes_num) {
   EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
 }
 
+void CheckNodesTest(const ProgramDesc& prog,
+                    std::unordered_map<std::string, int> expected_operators,
+                    const int removed_nodes_num) {
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  PrepareGraph(&graph, prog);
+
+  int original_nodes_num = graph->Nodes().size();
+  RegisterPass(&graph);
+  int current_nodes_num = graph->Nodes().size();
+
+  EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp()) {
+      auto* op = node->Op();
+      if (expected_operators.count(op->Type()) > 0) {
+        expected_operators[op->Type()]--;
+      }
+    }
+  }
+  for (auto const& pair : expected_operators) {
+    EXPECT_EQ(pair.second, 0) << " " << pair.first;
+  }
+}
+
 // check op->scale_out
 void EqualScaleTest(const ProgramDesc& prog, const std::string& op_name,
                     const std::string& scale_name, float scale) {
@@ -764,6 +833,18 @@ TEST(CpuQuantizeSquashPass, quant_bf16_conv2d) {
                 remove_nodes);
 }
 
+TEST(CpuQuantizeSquashPass, dont_squash_u8_dequant_s8_quant_input_to_concat) {
+  // removed 2 x 4 (dequantize_op, dequantize_out, quantize, quantize_out)
+  auto remove_nodes = 8;
+  std::unordered_map<std::string, int> expected_operators = {{"concat", 1},
+                                                             {"quantize", 1},
+                                                             {"dequantize", 1},
+                                                             {"relu", 1},
+                                                             {"pool2d", 2}};
+  CheckNodesTest(BuildConvS8U8S8ConcatProgramDesc(1.2f, 1.2f),
+                 expected_operators, remove_nodes);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index ef9d03d1dcbafae12675422978b998bf16bf54a0..9d22e1b4b520cd29c9265979344d6d5244feebca 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -116,11 +116,15 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
     // force unsigned type if already know it
     bool is_unsigned = false;
     bool compute_scale = true;
-    if (op->Type() == "conv2d" || op->Type() == "fc") {
+    if (op->Type() == "conv2d") {
       // output of conv2d with relu must be unsigned
       std::string fuse_activation =
          op->GetAttrIfExists<std::string>("fuse_activation");
      is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
+    } else if (op->Type() == "fc") {
+      std::string activation_type =
+          op->GetAttrIfExists<std::string>("activation_type");
+      is_unsigned = (activation_type == "relu" || activation_type == "relu6");
    } else if (op->Type() == "relu") {
      is_unsigned = true;
    } else if (op->Type() == "transpose2" || op->Type() == "reshape2" ||
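
Note: the guard added in cpu_quantize_squash_pass.cc reflects that a dequantize/quantize pair which changes signedness cannot simply be squashed away or folded into a plain requantize. The standalone sketch below is not part of this patch; it only illustrates the underlying arithmetic, using the test's scale of 1.2f and the convention that quantized = real * scale. A u8 activation (e.g. a relu output) can hold values above 127, so reusing its raw bytes where s8 data is expected silently corrupts the value, while the explicit dequantize/quantize round trip saturates correctly:

// Illustration only; not Paddle code. Assumes quantized = real * scale.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float scale = 1.2f;         // same scale as in the new unit test
  const std::uint8_t u8_val = 200;  // valid u8 activation after relu, > 127

  // Explicit dequantize followed by s8 quantize: saturates into [-128, 127].
  const float real_val = static_cast<float>(u8_val) / scale;
  const float q = std::round(real_val * scale);
  const std::int8_t s8_val =
      static_cast<std::int8_t>(std::max(-128.0f, std::min(127.0f, q)));

  // Dropping the pair keeps the raw byte: 200 reinterpreted as s8 is -56 on
  // two's-complement targets, not the saturated 127 computed above.
  const std::int8_t reinterpreted = static_cast<std::int8_t>(u8_val);

  std::printf("dequantize+quantize: %d, raw byte as s8: %d\n",
              static_cast<int>(s8_val), static_cast<int>(reinterpreted));
  return 0;
}

This is why the new test expects one dequantize/quantize pair (the one behind the relu) to survive the pass while the two signed pairs around the pool2d ops are squashed, removing 8 nodes.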