From 96845d216884f46a84ee072d677a137f4371cba4 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Mon, 27 May 2019 15:27:58 +0200 Subject: [PATCH] add Concat quantization (#17448) * add Concat quantization add unit test for quantizing concat fix for wrong value when the input is not in map of calculated scales add use_quantizer to concat_op.cc add scale_algo rules for concat test=develop * missing fix for multiple inputs quantize-squash * wojtuss review fix: adding comment test=develop --- .../framework/ir/graph_pattern_detector.cc | 11 ++ .../framework/ir/graph_pattern_detector.h | 13 +++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 91 ++++++++++++++++ .../framework/ir/mkldnn/cpu_quantize_pass.h | 7 ++ .../ir/mkldnn/cpu_quantize_pass_tester.cc | 100 ++++++++++++++++++ .../ir/mkldnn/cpu_quantize_squash_pass.cc | 14 +-- .../fluid/inference/api/mkldnn_quantizer.cc | 62 ++++++----- .../inference/api/mkldnn_quantizer_config.cc | 7 +- paddle/fluid/operators/concat_op.cc | 6 ++ 9 files changed, 272 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index d50ca636035..16a0f0d03fc 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1214,6 +1214,17 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) { return out_var; } +PDNode *patterns::Concat::operator()() { + auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat"); + + auto output_var = pattern->NewNode(concat_out_repr()) + ->AsOutput() + ->assert_is_op_output("concat", "Out"); + + concat_op->LinksTo({output_var}); + return output_var; +} + PDNode *patterns::ConcatReLU::operator()() { auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat"); auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 41f9d128585..4a90f086fe4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -747,6 +747,19 @@ struct ElementwiseAdd : public PatternBase { PATTERN_DECL_NODE(elementwise_add_out); }; +// Concat op +// Forward pass for concat. +// concat_out is a result of the operator. +struct Concat : public PatternBase { + Concat(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "concat") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(concat_op); + PATTERN_DECL_NODE(concat_out); +}; + // Concat + ReLU // named nodes: // concat_op, concat_out, relu_op, relu_out diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index dff98e523ac..dd3ee50e040 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h" +#include #include #include #include "paddle/fluid/framework/eigen.h" @@ -72,6 +73,53 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); } +void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, + VarQuantScale* scales, bool are_unsigned, + std::string scale_attr_name) const { + auto inputs = op->inputs; + PADDLE_ENFORCE_GE(inputs.size(), 1); + + // create a quantize op desc prototype + OpDesc q_desc; + q_desc.SetType("quantize"); + + std::vector quantize_out_nodes(inputs.size()); + std::vector quantize_out_node_names(inputs.size()); + + double scale_min = std::numeric_limits::max(); + for (const auto& input : inputs) { + double scale = (*scales)[input->Name()].second.data()[0]; + if (scale < scale_min) scale_min = scale; + } + unsigned max = are_unsigned ? U8_MAX : S8_MAX; + float scale = scale_min * max; + + for (size_t i = 0; i < inputs.size(); i++) { + // Create quantize output variable + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); + quantize_out_nodes[i] = g->CreateVarNode(&quantize_out_desc); + quantize_out_node_names[i] = quantize_out_nodes[i]->Name(); + + q_desc.SetAttr("Scale", scale); + q_desc.SetInput("Input", std::vector({inputs[i]->Name()})); + q_desc.SetOutput("Output", + std::vector({quantize_out_node_names[i]})); + q_desc.SetAttr("is_negative_input", !are_unsigned); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. + + // link quantize op + UnlinkNodes(inputs[i], op); + IR_NODE_LINK_TO(inputs[i], quantize_op); + IR_NODE_LINK_TO(quantize_op, quantize_out_nodes[i]); + IR_NODE_LINK_TO(quantize_out_nodes[i], op); + } + + // update op's input + op->Op()->SetInput(input_name, quantize_out_node_names); + + if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); +} + void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, std::string output_name, double scale_to_one, bool is_unsigned, @@ -216,6 +264,48 @@ void CPUQuantizePass::QuantizePool(Graph* graph) const { PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count); } +void CPUQuantizePass::QuantizeConcat(Graph* graph) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + patterns::Concat concat_pattern{pattern, name_scope_}; + concat_pattern(); + + int quantize_concat_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize concat op"; + GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_pattern); + auto* concat_op_desc = concat_op->Op(); + + // skip if should not be quantized + if (!concat_op_desc->HasAttr("use_quantizer") || + !boost::get(concat_op_desc->GetAttr("use_quantizer"))) + return; + + GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern); + + // get scales calculated after warmup, they scale variables to MAX=1.0 + auto scales = Get("quant_var_scales"); + + // if all inputs were unsigned, then the output was set to unsigned + // during the scale calculation step + bool are_all_inputs_unsigned = scales[concat_out->Name()].first; + QuantizeInputs(g, concat_op, "X", &scales, are_all_inputs_unsigned); + + auto output_scale = scales[concat_out->Name()].second.data()[0]; + + DequantizeOutput(g, concat_op, concat_out, "Out", output_scale, + are_all_inputs_unsigned); + + ++quantize_concat_count; + }; + + gpd(graph, handler); + AddStatis(quantize_concat_count); + + PrettyLogDetail("--- quantized %d concat ops", quantize_concat_count); +} + void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; PADDLE_ENFORCE(graph); @@ -226,6 +316,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizeConv(graph, false /* with_residual_data */); QuantizeConv(graph, true /* with_residual_data */); QuantizePool(graph); + QuantizeConcat(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index a178c4dc363..61a28fd3131 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -48,10 +48,17 @@ class CPUQuantizePass : public FusePassBase { void QuantizePool(Graph* graph) const; + void QuantizeConcat(Graph* graph) const; + void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, double scale_to_one, bool is_unsigned, std::string scale_attr_name = "") const; + // quantize all inputs of given name with the same (minimum) scale + void QuantizeInputs(Graph* g, Node* op, std::string input_name, + VarQuantScale* scales, bool are_unsigned, + std::string scale_attr_name = "") const; + void DequantizeOutput(Graph* g, Node* op, Node* output, std::string output_name, double scale_to_one, bool is_unsigned, diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 7d9d0ead0fe..c46ffad036d 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -60,9 +60,14 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, if (inputs.size() > 1) op->SetInput("W", {inputs[1]}); if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); op->SetOutput("Out", {outputs[0]}); + } else if (type == "concat") { + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); + op->SetAttr("use_quantizer", use_quantizer); } } +namespace { static const std::initializer_list variable_names{ "a", "w1", "c", "d", "w2", "e", "f", "g", "h", "w3", "b1", "i", "j", "w4", "b2"}; @@ -204,6 +209,101 @@ TEST(CpuQuantizePass, do_not_quantize) { 1.0f); } +} // namespace + +namespace { +static const std::initializer_list variable_names_concat = { + "a1", "b1", "a2", "b2", "c", "d"}; + +// a1->Pool1->b1 +// a2->Pool2->b2 +// (b1,b2)->Concat->c +// c->Pool3->d +ProgramDesc BuildProgramDescConcat() { + ProgramDesc prog; + + SetOp(&prog, "pool2d", "Pool1", {"a1"}, {"b1"}, true, false); + SetOp(&prog, "pool2d", "Pool2", {"a2"}, {"b2"}, true, false); + SetOp(&prog, "concat", "Concat", {"b1", "b2"}, {"c"}, true, true); + SetOp(&prog, "pool2d", "Pool3", {"c"}, {"d"}, true, false); + + return prog; +} + +void MainTestConcat(const ProgramDesc& prog, int pool_count, int concat_count, + int quant_count, int dequant_count, int added_nodes_count) { + std::unique_ptr graph(new ir::Graph(prog)); + + // Init scope, as it is used in pass + auto place = paddle::platform::CPUPlace(); + NaiveExecutor exe{place}; + Scope scope; + exe.CreateVariables(prog, 0, true, &scope); + + auto* scales = new VarQuantScale(); + + for (auto& v : variable_names_concat) { + InitTensorHolder(&scope, place, v.c_str()); + LoDTensor tensor; + tensor.Resize({1}); + auto* ptr = tensor.mutable_data(place); + ptr[0] = 2.0; + + (*scales)[v] = std::make_pair(false, std::move(tensor)); + } + + graph->SetNotOwned(kParamScopeAttr, &scope); + + auto pass = PassRegistry::Instance().Get("cpu_quantize_pass"); + pass->Set("quant_var_scales", scales); + + int original_nodes_num = graph->Nodes().size(); + + graph.reset(pass->Apply(graph.release())); + + int current_nodes_num = graph->Nodes().size(); + + int quantize_nodes_count = 0; + int dequantize_nodes_count = 0; + int concat_nodes_count = 0; + int pool2d_nodes_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->Type() == "concat") { + concat_nodes_count++; + } else if (op->Type() == "pool2d") { + pool2d_nodes_count++; + } else if (op->Type() == "quantize") { + quantize_nodes_count++; + } else if (op->Type() == "dequantize") { + dequantize_nodes_count++; + } + } + } + EXPECT_EQ(concat_nodes_count, concat_count); + EXPECT_EQ(pool2d_nodes_count, pool_count); + EXPECT_EQ(quantize_nodes_count, quant_count); + EXPECT_EQ(dequantize_nodes_count, dequant_count); + EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); +} + +TEST(CpuQuantizePass, concat) { + // a1->Pool1->b1 + // a2->Pool2->b2 + // (b1->QUANT1->IN1, b2->QUANT2->IN2)->Concat->c + // c->OUT1->DEQUANT1->Pool3->d + int pool_count = 3; + int concat_count = 1; + int quant_count = 2; + int dequant_count = 1; + int added_nodes_count = 6; + MainTestConcat(BuildProgramDescConcat(), pool_count, concat_count, + quant_count, dequant_count, added_nodes_count); +} + +} // namespace + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index de2e2d744c3..2270e2b5cc5 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -14,6 +14,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.h" +#include #include #include #include "paddle/fluid/platform/enforce.h" @@ -81,15 +82,10 @@ void CPUQuantizeSquashPass::Squash( auto quant_out_var_name = quant_out->Name(); auto next_op_inputs = next_op_desc->InputNames(); for (const auto& name : next_op_inputs) { - if (next_op_desc->Inputs().count(name) == 0 || - next_op_desc->Input(name).size() == 0) - continue; - auto var_name = next_op_desc->Input(name)[0]; - if (var_name.compare(quant_out_var_name) == 0) { - next_op_desc->SetInput( - name, std::vector({dequant_in->Name()})); - break; - } + auto input_names = next_op_desc->Input(name); + std::replace(input_names.begin(), input_names.end(), quant_out_var_name, + dequant_in->Name()); + next_op_desc->SetInput(name, input_names); } if (keep_dequant) diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 0765d300f45..df9678d693a 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -50,40 +50,46 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() { auto glambda = [&](const VariableNameMap& connections, bool is_output) { for (auto const& conn : connections) { - if (conn.second.size() == 0) continue; - auto& var_name = conn.second[0]; - - // skip if scale already computed - if (scales_.find(var_name) != scales_.end()) return; - - auto* var = predictor_.sub_scope_->FindVar(var_name); - PADDLE_ENFORCE(var, "%s is not in the scope", var_name); - PADDLE_ENFORCE(var->IsType(), - "Only support lod tensor now."); - LoDTensor* var_tensor = var->GetMutable(); - - // force unsigned type if already know it - bool is_unsigned = false; - if (is_output && op->Type() == "conv2d") { - // output of conv2d with relu must be unsigned - is_unsigned = op->HasAttr("fuse_relu") && - boost::get(op->GetAttr("fuse_relu")); - } else if (is_output && op->Type() == "pool2d") { - // output of pool2d with unsigned input must be unsigned - auto input_var_name = op->Input("X")[0]; - if (scales_.find(input_var_name) != scales_.end()) { - is_unsigned = scales_[input_var_name].first; + for (const auto& var_name : conn.second) { + // skip if scale already computed + if (scales_.find(var_name) != scales_.end()) return; + + auto* var = predictor_.sub_scope_->FindVar(var_name); + PADDLE_ENFORCE(var, "%s is not in the scope", var_name); + PADDLE_ENFORCE(var->IsType(), + "Only support lod tensor now."); + LoDTensor* var_tensor = var->GetMutable(); + + // force unsigned type if already know it + bool is_unsigned = false; + if (is_output && op->Type() == "conv2d") { + // output of conv2d with relu must be unsigned + is_unsigned = op->HasAttr("fuse_relu") && + boost::get(op->GetAttr("fuse_relu")); + } else if (is_output && op->Type() == "relu") { + is_unsigned = true; + } else if (is_output && + (op->Type() == "pool2d" || op->Type() == "transpose2" || + op->Type() == "reshape2" || op->Type() == "concat")) { + // output of ops with unsigned input must be unsigned + is_unsigned = true; + for (auto input_var_name : op->Input("X")) { + PADDLE_ENFORCE(scales_.find(input_var_name) != scales_.end(), + "Input scales must be calculated before the " + "output scales to infer if output is unsigned."); + is_unsigned = is_unsigned && scales_[input_var_name].first; + } } - } - CalculateSingleScale(op->Type(), conn.first, var_name, *var_tensor, - is_unsigned); + CalculateSingleScale(op->Type(), conn.first, var_name, *var_tensor, + is_unsigned); + } } }; - // handle outputs first so unsigned outputs could be inferred - glambda(connections_out, true /* is_output */); + // handle inputs first to let is_unsigned be inferred for the outputs glambda(connections_in, false /* is_output */); + glambda(connections_out, true /* is_output */); } } diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc index f9ff542d86d..a7cb785fe95 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc @@ -22,10 +22,13 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { rules_["conv2d"]["Filter"] = ScaleAlgo::MAX_CH; rules_["conv2d"]["Bias"] = ScaleAlgo::NONE; // do not compute scale rules_["conv2d"]["ResidualData"] = ScaleAlgo::KL; - rules_["conv2d"]["Output"] = ScaleAlgo::KL; // do not compute scale + rules_["conv2d"]["Output"] = ScaleAlgo::KL; rules_["pool2d"]["X"] = ScaleAlgo::KL; - rules_["pool2d"]["Out"] = ScaleAlgo::KL; // do not compute scale + rules_["pool2d"]["Out"] = ScaleAlgo::KL; + + rules_["concat"]["X"] = ScaleAlgo::KL; + rules_["concat"]["Out"] = ScaleAlgo::KL; } ScaleAlgo MkldnnQuantizerConfig::scale_algo( diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index ad7a84009c9..2e9cfaa599e 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -117,6 +117,12 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("axis", "The axis along which the input tensors will be concatenated.") .SetDefault(0); + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. " + "Only used on CPU.") + .SetDefault(false); AddComment(R"DOC( Concat Operator. -- GitLab