From 2579ade45fb0d2d698ee4ee0a8f6f4162f9ddb5f Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Mon, 18 Mar 2019 04:28:11 +0100 Subject: [PATCH] Add cpu_quantize_pass for C-API quantization (#16127) * Add cpu_quantize_pass for C-API quantization test=develop * add cpu_quantize_pass test * fix lint: add include memory unorderd_map and unordered_set test=develop * fuse_relu 1 test=develop * tuned 2 without squash * fixes test=develop * remove unused vars test=develop * refactored test=develop * fix lint c-style cast -> C++ style cast test=develop * remove QuantMax and c style casts test=develop * last usage of QuantMax removed test=develop * Fix Analysis Predictor UT Check if memory_optimize_pass has already been added to the analysis config before adding a new one, so that it is not added multiple times. test=develop * change map to unordered_map fix the forgotten part of cpu_quantize_pass_tester.cc test=develop * removed quantized attribute * fixed cpu_quantize_pass_tester and op attr comments test=develop * removed redundant line test=debug * removed gmock test=develop * fix after merge --- paddle/fluid/framework/ir/CMakeLists.txt | 4 +- .../fluid/framework/ir/cpu_quantize_pass.cc | 239 ++++++++++++++++++ paddle/fluid/framework/ir/cpu_quantize_pass.h | 66 +++++ .../framework/ir/cpu_quantize_pass_tester.cc | 211 ++++++++++++++++ .../framework/ir/graph_pattern_detector.cc | 51 +++- .../framework/ir/graph_pattern_detector.h | 29 +++ paddle/fluid/inference/analysis/argument.h | 6 + .../inference/analysis/ir_pass_manager.cc | 11 +- paddle/fluid/inference/api/analysis_config.cc | 9 +- paddle/fluid/operators/conv_op.cc | 7 + .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 1 + paddle/fluid/operators/pool_op.cc | 7 + 12 files changed, 631 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/framework/ir/cpu_quantize_pass.cc create mode 100644 paddle/fluid/framework/ir/cpu_quantize_pass.h create mode 100644 paddle/fluid/framework/ir/cpu_quantize_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index bfab221a9ec..3808dd5fbae 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -46,6 +46,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base) +pass_library(cpu_quantize_pass inference) pass_library(cpu_quantize_squash_pass inference) pass_library(fc_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference) @@ -102,10 +103,11 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) +cc_test(test_cpu_quantize_pass SRCS cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor) +cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) if(NOT WIN32) cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) endif() -cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) if (WITH_MKLDNN) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) diff --git a/paddle/fluid/framework/ir/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/cpu_quantize_pass.cc new file mode 100644 index 00000000000..edfaf47f018 --- /dev/null +++ b/paddle/fluid/framework/ir/cpu_quantize_pass.cc @@ -0,0 +1,239 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/cpu_quantize_pass.h" +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace { + +void UnlinkNodes(ir::Node* a, ir::Node* b) { + a->outputs.erase(std::remove(a->outputs.begin(), a->outputs.end(), b), + a->outputs.end()); + b->inputs.erase(std::remove(b->inputs.begin(), b->inputs.end(), a), + b->inputs.end()); +} + +} // namespace + +enum { U8_MAX = 255, S8_MAX = 127 }; + +using EigenVectorArrayMap = Eigen::Map>; +using string::PrettyLogDetail; + +void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, + std::string input_name, double scale_to_one, + bool is_unsigned, + std::string scale_attr_name) const { + unsigned max = is_unsigned ? U8_MAX : S8_MAX; + float scale = scale_to_one * max; + + // Create quantize output variable + VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); + auto* quantize_out_node = g->CreateVarNode(&quantize_out_desc); + + // create a quantize op node + OpDesc q_desc; + q_desc.SetType("quantize"); + q_desc.SetInput("Input", std::vector({input->Name()})); + q_desc.SetOutput("Output", + std::vector({quantize_out_node->Name()})); + q_desc.SetAttr("Scale", scale); + q_desc.SetAttr("is_negative_input", !is_unsigned); + auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. + + // update op's input + op->Op()->SetInput(input_name, + std::vector({quantize_out_node->Name()})); + + // link quantize op + UnlinkNodes(input, op); + IR_NODE_LINK_TO(input, quantize_op); + IR_NODE_LINK_TO(quantize_op, quantize_out_node); + IR_NODE_LINK_TO(quantize_out_node, op); + + if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); +} + +void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, + std::string output_name, + double scale_to_one, bool is_unsigned, + std::string scale_attr_name) const { + unsigned max = is_unsigned ? U8_MAX : S8_MAX; + float scale = scale_to_one * max; + + // Create dequantize input variable + VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); + auto* dequantize_in_node = g->CreateVarNode(&dequantize_in_desc); + + // create a dequantize op node for output. + OpDesc deq_desc; + deq_desc.SetType("dequantize"); + deq_desc.SetInput("Input", + std::vector({dequantize_in_node->Name()})); + deq_desc.SetOutput("Output", std::vector({output->Name()})); + deq_desc.SetAttr("Scale", scale); + auto dequantize_op = g->CreateOpNode(&deq_desc); // OpDesc will be copied. + + // update op's output + op->Op()->SetOutput(output_name, + std::vector({dequantize_in_node->Name()})); + + // link dequantize op + UnlinkNodes(op, output); + IR_NODE_LINK_TO(op, dequantize_in_node); + IR_NODE_LINK_TO(dequantize_in_node, dequantize_op); + IR_NODE_LINK_TO(dequantize_op, output); + + if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); +} + +void CPUQuantizePass::QuantizeConv(Graph* graph, + bool with_residual_data) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + patterns::ConvResidual conv_pattern{pattern, name_scope_}; + conv_pattern(with_residual_data); + + int quantize_conv_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize conv2d op"; + GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); + auto* conv_op_desc = conv_op->Op(); + + // skip if should not be quantized + if (!conv_op_desc->HasAttr("use_quantizer") || + !boost::get(conv_op_desc->GetAttr("use_quantizer"))) + return; + + GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); + GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); + + // get scales calculated after warmup, they scale variables to MAX=1.0 + auto scales = Get("quant_var_scales"); + + auto input_scale = scales[conv_input->Name()].second.data()[0]; + bool is_input_unsigned = scales[conv_input->Name()].first; + QuantizeInput(g, conv_op, conv_input, "Input", input_scale, + is_input_unsigned, "Scale_in"); + + auto filter_scale_tensor = scales[conv_filter->Name()].second; + EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data(), + filter_scale_tensor.numel(), 1}; + eigen_tensor *= static_cast(S8_MAX); + std::vector filter_scale{ + filter_scale_tensor.data(), + filter_scale_tensor.data() + filter_scale_tensor.numel()}; + + conv_op->Op()->SetAttr("Scale_weights", filter_scale); + + if (with_residual_data) { + GET_IR_NODE_FROM_SUBGRAPH(conv_residual_data, conv_residual_data, + conv_pattern); + auto residual_scale = + scales[conv_residual_data->Name()].second.data()[0]; + bool is_residual_unsigned = scales[conv_residual_data->Name()].first; + + QuantizeInput(g, conv_op, conv_residual_data, "ResidualData", + residual_scale, is_residual_unsigned, "Scale_in_eltwise"); + } + + auto output_scale = scales[conv_output->Name()].second.data()[0]; + bool is_output_unsigned = scales[conv_output->Name()].first; + DequantizeOutput(g, conv_op, conv_output, "Output", output_scale, + is_output_unsigned, "Scale_out"); + + ++quantize_conv_count; + }; + + gpd(graph, handler); + AddStatis(quantize_conv_count); + + std::stringstream msg_ss; + msg_ss << "--- quantized " << quantize_conv_count << " conv2d ops"; + if (with_residual_data) msg_ss << " with residual connection"; + PrettyLogDetail(msg_ss.str().c_str()); +} + +void CPUQuantizePass::QuantizePool(Graph* graph) const { + GraphPatternDetector gpd; + auto pattern = gpd.mutable_pattern(); + patterns::Pool pool_pattern{pattern, name_scope_}; + pool_pattern(); + + int quantize_pool_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize pool2d op"; + GET_IR_NODE_FROM_SUBGRAPH(pool_op, pool_op, pool_pattern); + auto* pool_op_desc = pool_op->Op(); + + // skip if should not be quantized + if (!pool_op_desc->HasAttr("use_quantizer") || + !boost::get(pool_op_desc->GetAttr("use_quantizer"))) + return; + + GET_IR_NODE_FROM_SUBGRAPH(pool_input, pool_input, pool_pattern); + GET_IR_NODE_FROM_SUBGRAPH(pool_output, pool_output, pool_pattern); + + // get scales calculated after warmup, they scale variables to MAX=1.0 + auto scales = Get("quant_var_scales"); + + auto input_scale = scales[pool_input->Name()].second.data()[0]; + bool is_input_unsigned = scales[pool_input->Name()].first; + QuantizeInput(g, pool_op, pool_input, "X", input_scale, is_input_unsigned); + + auto output_scale = scales[pool_output->Name()].second.data()[0]; + bool is_output_unsigned = scales[pool_output->Name()].first; + DequantizeOutput(g, pool_op, pool_output, "Out", output_scale, + is_output_unsigned); + + ++quantize_pool_count; + }; + + gpd(graph, handler); + AddStatis(quantize_pool_count); + + PrettyLogDetail("--- quantized %d pool2d ops", quantize_pool_count); +} + +std::unique_ptr CPUQuantizePass::ApplyImpl( + std::unique_ptr graph) const { + VLOG(3) << "Quantizing the graph."; + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + PADDLE_ENFORCE(param_scope()); + + QuantizeConv(graph.get(), true /* with_residual_data */); + QuantizeConv(graph.get()); + QuantizePool(graph.get()); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(cpu_quantize_pass, paddle::framework::ir::CPUQuantizePass) + .RequirePassAttr("quant_var_scales"); diff --git a/paddle/fluid/framework/ir/cpu_quantize_pass.h b/paddle/fluid/framework/ir/cpu_quantize_pass.h new file mode 100644 index 00000000000..9873bb04e13 --- /dev/null +++ b/paddle/fluid/framework/ir/cpu_quantize_pass.h @@ -0,0 +1,66 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Map variable name to tensor of scaling factors scaling it to MAX=1.0. + * bool denotes whether quantization of the variable should be done to unsigned + * type. + */ +using VarQuantScale = + std::unordered_map>; + +/* + * Quantize all supported operators. + */ +class CPUQuantizePass : public FusePassBase { + public: + virtual ~CPUQuantizePass() {} + + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; + + void QuantizeConv(Graph* graph, bool with_residual_data = false) const; + + void QuantizePool(Graph* graph) const; + + void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, + double scale_to_one, bool is_unsigned, + std::string scale_attr_name = "") const; + + void DequantizeOutput(Graph* g, Node* op, Node* output, + std::string output_name, double scale_to_one, + bool is_unsigned, + std::string scale_attr_name = "") const; + + const std::string name_scope_{"quantize"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/cpu_quantize_pass_tester.cc new file mode 100644 index 00000000000..89601be7d1c --- /dev/null +++ b/paddle/fluid/framework/ir/cpu_quantize_pass_tester.cc @@ -0,0 +1,211 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/cpu_quantize_pass.h" +#include +#include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, bool use_mkldnn, + bool use_quantizer = false) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("name", name); + if (type == "conv2d") { + op->SetInput("Input", {inputs[0]}); + op->SetInput("Filter", {inputs[1]}); + if (inputs.size() > 2) + op->SetInput("Bias", {inputs[2]}); + else + op->SetInput("Bias", {}); + if (inputs.size() > 3) { + op->SetInput("ResidualData", {inputs[3]}); + op->SetAttr("fuse_residual_connection", true); + } else { + op->SetInput("ResidualData", {}); + op->SetAttr("fuse_residual_connection", false); + } + op->SetOutput("Output", {outputs[0]}); + op->SetAttr("use_quantizer", use_quantizer); + op->SetAttr("Scale_in", 1.0f); + op->SetAttr("Scale_out", 1.0f); + op->SetAttr("Scale_weights", std::vector{1.0f}); + } else if (type == "pool2d") { + op->SetInput("X", {inputs[0]}); + op->SetOutput("Out", {outputs[0]}); + op->SetAttr("use_quantizer", use_quantizer); + } else if (type == "dropout") { + op->SetInput("X", {inputs[0]}); + op->SetOutput("Out", {outputs[0]}); + } else if (type == "fc") { + op->SetInput("Input", {inputs[0]}); + if (inputs.size() > 1) op->SetInput("W", {inputs[1]}); + if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]}); + op->SetOutput("Out", {outputs[0]}); + } +} + +static const std::initializer_list variable_names{ + "a", "w1", "c", "d", "w2", "e", "f", "g", + "h", "w3", "b1", "i", "j", "w4", "b2"}; +// (a,w1)->Conv1->c and c->Pool1->d +// +// (d,w2)->Conv2->e and e->Pool2->f +// +// d->Dropout1->g and g->Fc1->h and (h,w3,b1,i)->Conv3->j +// +// (d,w4, b2)->Conv4->i +ProgramDesc BuildProgramDesc(bool use_mkldnn, bool use_quantizer) { + ProgramDesc prog; + for (auto& v : variable_names) { + auto* var = prog.MutableBlock(0)->Var(v); + if (v.find("w") == 0 || v.find("b") == 0) { + var->SetPersistable(true); + } + } + + SetOp(&prog, "conv2d", "Conv1", {"a", "w1"}, {"c"}, use_mkldnn, + use_quantizer); + SetOp(&prog, "pool2d", "Pool1", {"c"}, {"d"}, use_mkldnn, use_quantizer); + + SetOp(&prog, "conv2d", "Conv2", {"d", "w2"}, {"e"}, use_mkldnn, + use_quantizer); + SetOp(&prog, "pool2d", "Pool2", {"e"}, {"f"}, use_mkldnn, use_quantizer); + + SetOp(&prog, "dropout", "Dropout1", {"d"}, {"g"}, use_mkldnn); + SetOp(&prog, "fc", "Fc1", {"g"}, {"h"}, use_mkldnn); + SetOp(&prog, "conv2d", "Conv3", {"h", "w3", "b1", "i"}, {"j"}, use_mkldnn, + use_quantizer); + + SetOp(&prog, "conv2d", "Conv4", {"c", "w4", "b2"}, {"i"}, use_mkldnn, + use_quantizer); + + return prog; +} + +void InitTensorHolder(Scope* scope, const paddle::platform::Place& place, + const char* var_name) { + auto x = scope->Var(var_name); + auto tensor = x->GetMutable(); + tensor->mutable_data(place, proto::VarType::FP32, + ::paddle::memory::Allocator::kDefault, 1); +} + +void MainTest(const ProgramDesc& prog, int conv_count, int pool_count, + int quant_count, int dequant_count, int added_nodes_count, + float scale) { + std::unique_ptr graph(new ir::Graph(prog)); + + // Init scope, as it is used in pass + auto place = paddle::platform::CPUPlace(); + NaiveExecutor exe{place}; + Scope scope; + exe.CreateVariables(prog, 0, true, &scope); + + auto* scales = new VarQuantScale(); + + for (auto& v : variable_names) { + InitTensorHolder(&scope, place, v.c_str()); + LoDTensor tensor; + tensor.Resize({1}); + auto* ptr = tensor.mutable_data(place); + ptr[0] = 2.0; + + (*scales)[v] = std::make_pair(false, std::move(tensor)); + } + + graph->Set(kParamScopeAttr, new framework::Scope*(&scope)); + + auto pass = PassRegistry::Instance().Get("cpu_quantize_pass"); + pass->Set("quant_var_scales", scales); + + int original_nodes_num = graph->Nodes().size(); + + graph = pass->Apply(std::move(graph)); + + int current_nodes_num = graph->Nodes().size(); + + int quantize_nodes_count = 0; + int dequantize_nodes_count = 0; + int conv2d_nodes_count = 0; + int pool2d_nodes_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->Type() == "conv2d") { + conv2d_nodes_count++; + auto op_name = boost::get(op->GetAttr("name")); + EXPECT_EQ(boost::get(op->GetAttr("Scale_in")), scale) + << "Scale_in for node '" + op_name + "'."; + EXPECT_EQ(boost::get(op->GetAttr("Scale_out")), scale) + << "Scale_out for node '" + op_name + "'."; + EXPECT_EQ( + boost::get>(op->GetAttr("Scale_weights"))[0], + scale) + << "Scale_weights for node '" + op_name + "'."; + } else if (op->Type() == "pool2d") { + pool2d_nodes_count++; + } else if (op->Type() == "quantize") { + quantize_nodes_count++; + } else if (op->Type() == "dequantize") { + dequantize_nodes_count++; + } + } + } + EXPECT_EQ(conv2d_nodes_count, conv_count); + EXPECT_EQ(pool2d_nodes_count, pool_count); + EXPECT_EQ(quantize_nodes_count, quant_count); + EXPECT_EQ(dequantize_nodes_count, dequant_count); + EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); +} + +TEST(CpuQuantizePass, quantize) { + bool use_mkldnn = true; + bool use_quantizer = true; + // (a->QUANT1->IN1,w1)->Conv1->OUT1->DEQUANT1->c and + // c->QUANT2->IN2->Pool1->OUT2->DEQUANT2->d + // + // (d->QUANT3->IN3,w2)->Conv2->OUT3->DEQUANT3->e and + // e->QUANT4->IN4->Pool2->OUT4->DEQUANT4->f + // + // d->Dropout1->g and g->Fc1->h and + // (h->QUANT5->IN5,w3,b1,i->QUANT6->IN6)->Conv3->OUT5->DEQUANT5->j + // + // (d->QUANT7->IN7,w4, b2)->Conv4->DEQUANT6->OUT6->i + // Insert nodes: 7 Quant + 7 IN + 6 OUT + 6 DEQUANT + int added_nodes = 7 + 7 + 6 + 6; + MainTest(BuildProgramDesc(use_mkldnn, use_quantizer), 4, 2, 7, 6, added_nodes, + 2.0f * 127); +} + +TEST(CpuQuantizePass, do_not_quantize) { + bool use_mkldnn = true; + bool use_quantizer = false; + int added_nodes = 0; + MainTest(BuildProgramDesc(use_mkldnn, use_quantizer), 4, 2, 0, 0, added_nodes, + 1.0f); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(cpu_quantize_pass); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 08354b526a0..b653e5a521e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -90,7 +90,8 @@ void GraphPatternDetector::operator()(Graph *graph, ValidateByNodeRole(&subgraphs); if (subgraphs.empty()) return; - PrettyLogEndl(Style::detail(), "--- detect %d subgraphs", subgraphs.size()); + PrettyLogEndl(Style::detail(), "--- detected %d subgraphs", + subgraphs.size()); int id = 0; for (auto &g : subgraphs) { VLOG(3) << "optimizing #" << id++ << " subgraph"; @@ -1074,9 +1075,53 @@ PDNode *patterns::Conv::operator()() { ->AsOutput() ->assert_is_op_output("conv2d", "Output"); - conv_op->LinksFrom({input_var, filter_var}); - conv_op->LinksTo({output_var}); + conv_op->LinksFrom({input_var, filter_var}).LinksTo({output_var}); + return output_var; +} + +PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { + auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); + + if (!with_residual_data) + conv_op->assert_op_attr("fuse_residual_connection", false); + + auto input_var = pattern->NewNode(conv_input_repr()) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + + auto filter_var = pattern->NewNode(conv_filter_repr()) + ->AsInput() + ->assert_is_op_input("conv2d", "Filter"); + + auto output_var = pattern->NewNode(conv_output_repr()) + ->AsOutput() + ->assert_is_op_output("conv2d", "Output"); + + std::vector links_from{input_var, filter_var}; + + if (with_residual_data) { + auto res_conn_var = pattern->NewNode(conv_residual_data_repr()) + ->AsInput() + ->assert_is_op_input("conv2d", "ResidualData"); + links_from.push_back(res_conn_var); + } + + conv_op->LinksFrom(links_from).LinksTo({output_var}); + return output_var; +} + +PDNode *patterns::Pool::operator()() { + auto pool_op = pattern->NewNode(pool_op_repr())->assert_is_op("pool2d"); + + auto input_var = pattern->NewNode(pool_input_repr()) + ->AsInput() + ->assert_is_op_input("pool2d", "X"); + + auto output_var = pattern->NewNode(pool_output_repr()) + ->AsOutput() + ->assert_is_op_output("pool2d", "Out"); + pool_op->LinksFrom({input_var}).LinksTo({output_var}); return output_var; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 3db4bba10d6..fc30b5b21c5 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -659,6 +659,35 @@ struct Conv : public PatternBase { PATTERN_DECL_NODE(conv_output); }; +// Convolution op with residual data +struct ConvResidual : public PatternBase { + ConvResidual(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_residual") {} + + PDNode* operator()(bool with_residual_data); + + PATTERN_DECL_NODE(conv_op); + PATTERN_DECL_NODE(conv_input); + PATTERN_DECL_NODE(conv_filter); + PATTERN_DECL_NODE(conv_residual_data); + PATTERN_DECL_NODE(conv_output); +}; + +// Pool op +// Forward pass for pooling. +// pool_input is the input. +// pool_output is a result of the operator. +struct Pool : public PatternBase { + Pool(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "pooling") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(pool_op); + PATTERN_DECL_NODE(pool_input); + PATTERN_DECL_NODE(pool_output); +}; + // ElementwiseAdd used in residual connections. // y_var is used and convolution output. // The operator is removed, when residual diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 89e934ae27b..321deccf867 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -38,7 +39,10 @@ namespace paddle { namespace inference { namespace analysis { + using framework::ir::Graph; +using VarQuantScale = + std::unordered_map>; /* * The argument definition of both Pass and PassManagers. @@ -127,6 +131,8 @@ struct Argument { // Pass a set of op types to enable its mkldnn kernel DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes, std::unordered_set); + // Scales for variables to be quantized + DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale); // Passed from config. DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 1cdb4881fbc..8fd86b2cc56 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include +#include #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" @@ -55,14 +56,14 @@ void IRPassManager::CreatePasses(Argument *argument, ".dot"; pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); pass_num++; - } - if (pass_name == "mkldnn_placement_pass") { + } else if (pass_name == "mkldnn_placement_pass") { pass->Set("mkldnn_enabled_op_types", new std::unordered_set( argument->mkldnn_enabled_op_types())); - } - - if (pass_name == "tensorrt_subgraph_pass") { + } else if (pass_name == "cpu_quantize_pass") { + pass->Set("quant_var_scales", + new VarQuantScale(argument->quant_var_scales())); + } else if (pass_name == "tensorrt_subgraph_pass") { pass->Set("workspace_size", new int(argument->tensorrt_workspace_size())); pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size())); pass->Set("min_subgraph_size", diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 77411112220..92526f4e74a 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -219,7 +219,14 @@ void AnalysisConfig::Update() { } if (enable_memory_optim_) { - pass_builder()->AppendAnalysisPass("memory_optimize_pass"); + auto analysis_passes = pass_builder()->AnalysisPasses(); + auto memory_opti_pass_name = "memory_optimize_pass"; + bool already_exists = + std::find(analysis_passes.begin(), analysis_passes.end(), + memory_opti_pass_name) != analysis_passes.end(); + if (!already_exists) { + pass_builder()->AppendAnalysisPass(memory_opti_pass_name); + } } if (ir_debug_) { diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index ca6bc4df0fe..c6121d00dae 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_op.h" +#include #include #include @@ -194,6 +195,12 @@ void Conv2DOpMaker::Make() { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. " + "Only used on CPU.") + .SetDefault(false); AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); AddAttr("fuse_residual_connection", diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 14ca3e8073b..8d96ae7e421 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -592,6 +592,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { platform::SetDstMemoryHandler(ctx, output, handler, &dst_memory_p); } else { + need_s8_to_u8 = fuse_relu; platform::SetDstMemoryHandler(ctx, output, handler, &dst_memory_p); } diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 0a0ece162cc..7963c27a015 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/pool_op.h" +#include #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cudnn_helper.h" #endif @@ -212,6 +213,12 @@ void Pool2dOpMaker::Make() { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. " + "Only used on CPU.") + .SetDefault(false); AddAttr( "data_format", "(string, default NCHW) Only used in " -- GitLab