From 0473cdb8b412d8ffaefa043e48e3578c489b4f86 Mon Sep 17 00:00:00 2001
From: Zhaolong Xing
Date: Mon, 2 Dec 2019 11:25:08 +0800
Subject: [PATCH] CHERRY_PICK: TRT int8: refine trt int8 for dynamic range set
 (#21112) (#21449)

---
 .../ir/delete_quant_dequant_op_pass.cc        | 21 ++++++++
 paddle/fluid/framework/ir/fc_fuse_pass.cc     |  2 +-
 .../ir/quant_conv2d_dequant_fuse_pass.cc      |  6 ++-
 .../inference/api/paddle_pass_builder.cc      |  7 +--
 .../inference/tensorrt/convert/conv2d_op.cc   | 10 +---
 .../tensorrt/convert/elementwise_op.cc        | 17 ++++---
 .../fluid/inference/tensorrt/convert/fc_op.cc |  9 +---
 .../tensorrt/convert/leaky_relu_op.cc         |  7 +++
 .../inference/tensorrt/convert/pool2d_op.cc   |  7 +--
 paddle/fluid/inference/tensorrt/engine.cc     | 26 +++++++---
 paddle/fluid/inference/tensorrt/op_teller.cc  |  4 ++
 .../fluid/inference/tests/api/CMakeLists.txt  |  8 +++
 .../tests/api/trt_quant_int8_test.cc          | 50 +++++++++++++++++++
 13 files changed, 138 insertions(+), 36 deletions(-)
 create mode 100644 paddle/fluid/inference/tests/api/trt_quant_int8_test.cc

diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
index 3d4df87ab7..4dfbd5e00c 100644
--- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
@@ -39,6 +39,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
   patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
                                                 pattern_name);
   pattern();
+  auto* scope = param_scope();

   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -47,10 +48,29 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
     std::string any_op_out_name = any_op_out->Var()->Name();
     std::string quant_dequant_op_out_name =
         quant_dequant_op_out->Var()->Name();
+    std::string input_scale_var_name =
+        quant_dequant_op->Op()->Input("InScale").front();
+    const LoDTensor& input_scale_tensor =
+        scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
+
+    const float* input_scale_data = input_scale_tensor.data<float>();
+    float input_scale = input_scale_data[0];
     auto* any_op2_desc = any_op2->Op();
     // auto input_args_names = any_op2_desc->InputArgumentNames();
     auto var_map = any_op2_desc->Inputs();
+    std::string arg_name = "";
+    for (auto& name_m : var_map) {
+      if (std::find(name_m.second.begin(), name_m.second.end(),
+                    quant_dequant_op_out_name) != name_m.second.end()) {
+        arg_name = name_m.first;
+      }
+    }
+    CHECK(arg_name.size() > 0) << "can not find the input "
+                               << quant_dequant_op_out_name;
+    any_op2_desc->SetAttr("enable_int8", true);
+    any_op2_desc->SetAttr(arg_name + "_scale", input_scale);
+
+    // modify the any_op2's inputs
     for (auto& name_m : var_map) {
       if (std::find(name_m.second.begin(), name_m.second.end(),
                     quant_dequant_op_out_name) != name_m.second.end()) {
@@ -65,6 +85,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
         any_op2_desc->Flush();
       }
     }
+    any_op2_desc->Flush();
     // Delete the unneeded nodes.
     GraphSafeRemoveNodes(graph,
                          {quant_dequant_op, quant_dequant_op_out,
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index b53e6a250c..1cb42fbe43 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -99,7 +99,7 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   auto* mul_op_desc = mul->Op();
   if (mul_op_desc->HasAttr("enable_int8")) {
     desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
-    desc.SetAttr("input_scale", mul_op_desc->GetAttr("input_scale"));
+    desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale"));
     desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
     if (mul_op_desc->HasAttr("out_scale"))
       desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
index 62fba440ed..d0ca452238 100644
--- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
@@ -140,22 +140,24 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times,

   framework::OpDesc new_op_desc(base_op_desc, nullptr);
   new_op_desc.SetType(quantized_op_type);
+  new_op_desc.SetAttr("enable_int8", true);

   if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
       quantized_op_type == "depthwise_conv2d") {
     new_op_desc.SetInput("Input", {new_input});
+    new_op_desc.SetAttr("Input_scale", input_scale);
     new_op_desc.SetOutput("Output", {new_output});
   } else if (quantized_op_type == "fc") {
     new_op_desc.SetInput("Input", {new_input});
+    new_op_desc.SetAttr("Input_scale", input_scale);
     new_op_desc.SetOutput("Out", {new_output});
   } else if (quantized_op_type == "mul") {
     new_op_desc.SetInput("X", {new_input});
+    new_op_desc.SetAttr("X_scale", input_scale);
     new_op_desc.SetOutput("Out", {new_output});
   }

-  new_op_desc.SetAttr("enable_int8", true);
-  new_op_desc.SetAttr("input_scale", input_scale);
   new_op_desc.SetAttr("weight_scale", weight_scale);
   new_op_desc.Flush();
   auto* new_op = graph->CreateOpNode(&new_op_desc);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 959c0c00c2..459fef40f1 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -76,9 +76,10 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "shuffle_channel_detect_pass",     //
       "quant_conv2d_dequant_fuse_pass",  //
       "delete_quant_dequant_op_pass",    //
-      // "fc_fuse_pass",                 //
-      "tensorrt_subgraph_pass",          //
-      "conv_bn_fuse_pass",               //
+      "conv_bn_fuse_pass",               //
+      "fc_fuse_pass",                    //
+      "tensorrt_subgraph_pass",          //
+      "conv_bn_fuse_pass",               //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                            // guaranteed at least v7
       "conv_elementwise_add_act_fuse_pass",  //
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
index 73bfa800f0..fd60d844ac 100644
--- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -40,7 +40,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,

   if (enable_int8) {
 #if IS_TRT_VERSION_GE(5000)
-    float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
+    CHECK(op_desc.HasAttr("Input_scale"));
+    float in_scale = boost::get<float>(op_desc.GetAttr("Input_scale"));
     auto weight_scale =
         boost::get<std::vector<float>>(op_desc.GetAttr("weight_scale"));
     weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
@@ -89,13 +90,6 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   layer->getOutput(0)->setName(output_name.c_str());
   engine->SetITensor(output_name, layer->getOutput(0));

-#if IS_TRT_VERSION_GE(5000)
-  if (enable_int8) {
-    float output_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
-    engine->SetTensorDynamicRange(layer->getOutput(0), output_scale);
-  }
-#endif
-
   if (test_mode) {
     engine->DeclareOutput(output_name);
   }
diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
index c61dd753a3..840369976d 100644
--- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
@@ -110,10 +110,11 @@ class ElementwiseWeightOpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name},
                              test_mode);
-    if (op_desc.HasAttr("out_scale")) {
+    if (op_desc.HasAttr("enable_int8")) {
 #if IS_TRT_VERSION_GE(5000)
-      float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
-      engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
+      CHECK(op_desc.HasAttr("X_scale"));
+      float x_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
+      engine_->SetTensorDynamicRange(X, x_scale);
 #endif
     }
   }
@@ -169,10 +170,14 @@ class ElementwiseTensorOpConverter : public OpConverter {
       layer = plugin_layer;
     }
     RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
-    if (op_desc.HasAttr("out_scale")) {
+    if (op_desc.HasAttr("enable_int8")) {
 #if IS_TRT_VERSION_GE(5000)
-      float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
-      engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
+      CHECK(op_desc.HasAttr("X_scale"));
+      CHECK(op_desc.HasAttr("Y_scale"));
+      float x_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
+      float y_scale = boost::get<float>(op_desc.GetAttr("Y_scale"));
+      engine_->SetTensorDynamicRange(X, x_scale);
+      engine_->SetTensorDynamicRange(Y, y_scale);
 #endif
     }
   }
diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc
index ea108d6a07..ec21bb5534 100644
--- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc
@@ -77,7 +77,8 @@ class FcOpConverter : public OpConverter {
     bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
     if (enable_int8) {
 #if IS_TRT_VERSION_GE(5000)
-      float in_scale = boost::get<float>(op_desc.GetAttr("input_scale"));
+      CHECK(op_desc.HasAttr(i_name + "_scale"));
+      float in_scale = boost::get<float>(op_desc.GetAttr(i_name + "_scale"));
       auto weight_scale =
           boost::get<std::vector<float>>(op_desc.GetAttr("weight_scale"));
       weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(),
@@ -135,12 +136,6 @@ class FcOpConverter : public OpConverter {

     auto output_name = op_desc.Output("Out").front();
     RreplenishLayerAndOutput(layer, "fc", {output_name}, test_mode);
-    if (enable_int8) {
-#if IS_TRT_VERSION_GE(5000)
-      float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
-      engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
-#endif
-    }
   }
 };

diff --git a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
index f3c714009f..a5581737ee 100644
--- a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
@@ -42,6 +42,13 @@ class LeakyReluOpConverter : public OpConverter {
         engine_, Activation, *input, nvinfer1::ActivationType::kLEAKY_RELU);
     layer->setAlpha(alpha);
     output_layer = layer;
+
+    bool enable_int8 = boost::get<bool>(op_desc.HasAttr("enable_int8"));
+    if (enable_int8) {
+      CHECK(op_desc.HasAttr("X_scale"));
+      float in_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
+      engine_->SetTensorDynamicRange(input, in_scale);
+    }
 #else
     platform::CPUPlace place;
     std::unique_ptr<framework::LoDTensor> alpha_tensor(
diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
index 09659af7af..846e2154b1 100644
--- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
@@ -160,10 +160,11 @@ class Pool2dOpConverter : public OpConverter {

     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode);
-    if (op_desc.HasAttr("out_scale")) {
+    if (op_desc.HasAttr("enable_int8")) {
 #if IS_TRT_VERSION_GE(5000)
-      float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
-      engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
+      CHECK(op_desc.HasAttr("X_scale"));
+      float input_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
+      engine_->SetTensorDynamicRange(input1, input_scale);
 #endif
     }
   }
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 85722c94b2..6f66e8d972 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -104,12 +104,31 @@ void TensorRTEngine::FreezeNetwork() {

       for (auto &t : all_t) {
         if (!quant_dynamic_range_.count(t)) {
-          LOG(WARNING)
+          VLOG(3)
               << "We are in trt int8 mode(not calibration), scale not setted"
               << " for tensor " << t->getName()
               << ", this might be ok when trt does not need this range";
         }
       }
+
+      std::unordered_set<std::string> all_out_t_name;
+      for (int i = 0; i < infer_network_->getNbOutputs(); i++) {
+        auto *temp = infer_network_->getOutput(i);
+        temp->setDynamicRange(-1, 1);
+        all_out_t_name.insert(temp->getName());
+      }
+
+      for (int i = 0; i < infer_network_->getNbLayers(); i++) {
+        auto layer = infer_network_->getLayer(i);
+        for (int j = 0; j < layer->getNbOutputs(); j++) {
+          auto *temp_out = layer->getOutput(j);
+          if (std::find(all_out_t_name.begin(), all_out_t_name.end(),
+                        temp_out->getName()) != all_out_t_name.end()) {
+            layer->setPrecision(nvinfer1::DataType::kFLOAT);
+            layer->setOutputType(j, nvinfer1::DataType::kFLOAT);
+          }
+        }
+      }
+
 #endif
   }
 }
@@ -215,11 +234,6 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
         (scale.size() == 1 || scale.size() == static_cast<size_t>(w_dims[0]));
     PADDLE_ENFORCE(valid_scale_size, "TRT int8 quant: invalid scale size");
     for (int i = 0; i < weight_tensor->numel(); i++) {
-      bool is_valid_int8 =
-          ((weight_data[i] >= -128) && (weight_data[i] <= 127));
-      PADDLE_ENFORCE(is_valid_int8,
-                     "We are in anakin subgraph int8 mode, the weight of conv "
-                     "should be in range [-128, 127]");
       if (scale.size() == 1) {
         weight_data[i] *= (scale[0] / 127);
       } else {
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index fce8daa1d6..aa9f330bc7 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -56,6 +56,10 @@ struct SimpleOpTypeSetTeller : public Teller {
 };

 bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
+  // do not support the op which is labeled the `skip_quant`
+  if (desc.HasAttr("op_namescope") &&
+      boost::get<std::string>(desc.GetAttr("op_namescope")) == "/skip_quant_2/")
+    return false;
   for (auto& teller : tellers_) {
     if (op_type == "pool2d" || op_type == "conv2d" ||
         op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index f8a61b46de..5952c1c4fc 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -328,6 +328,14 @@ if(WITH_GPU AND TENSORRT_FOUND)
   inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
           ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
+
+  set(TRT_MODEL_QUANT_RESNET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant_small_model")
+  if (NOT EXISTS ${TRT_MODEL_QUANT_RESNET_DIR})
+    inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "quant_small_model.tar.gz")
+  endif()
+  inference_analysis_test(trt_quant_int8_test SRCS trt_quant_int8_test.cc
+          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+          ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})
 endif()

 inference_analysis_test(test_analyzer_capi SRCS analyzer_capi_tester.cc
diff --git a/paddle/fluid/inference/tests/api/trt_quant_int8_test.cc b/paddle/fluid/inference/tests/api/trt_quant_int8_test.cc
new file mode 100644
index 0000000000..e1ce9d5c20
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/trt_quant_int8_test.cc
@@ -0,0 +1,50 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+
+namespace paddle {
+namespace inference {
+
+TEST(quant_int8, resnet50) {
+  std::string model_dir = FLAGS_infer_model;
+  AnalysisConfig config;
+  config.EnableUseGpu(100, 0);
+  config.SetModel(model_dir);
+  config.SwitchUseFeedFetchOps(false);
+  config.EnableTensorRtEngine(1 << 30, 1, 1, AnalysisConfig::Precision::kInt8,
+                              false, false);
+
+  auto predictor = CreatePaddlePredictor(config);
+  auto input_names = predictor->GetInputNames();
+  int channels = 1;
+  int height = 3;
+  int width = 3;
+  int input_num = channels * height * width * 1;
+
+  float *input = new float[input_num];
+  memset(input, 0, input_num * sizeof(float));
+  auto input_t = predictor->GetInputTensor(input_names[0]);
+  input_t->Reshape({1, channels, height, width});
+  input_t->copy_from_cpu(input);
+
+  ASSERT_TRUE(predictor->ZeroCopyRun());
+}
+
+}  // namespace inference
+}  // namespace paddle
--
GitLab
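Editor's usage sketch (not part of the patch): the new test above feeds zeros and only asserts that ZeroCopyRun() succeeds. For readers who want to see the whole round trip this patch enables — int8 TensorRT without a calibration table, with dynamic ranges taken from the *_scale attributes the passes propagate — here is a minimal caller that also reads the prediction back. The helper name RunQuantInt8 and the zero-filled input are illustrative; the API calls themselves (AnalysisConfig, ZeroCopyTensor, CreatePaddlePredictor) are the same public Paddle inference API the test uses.

#include <functional>
#include <numeric>
#include <string>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Run a quant-aware-trained model on the TRT int8 path and return the output.
// `shape` is NCHW, e.g. {1, 1, 3, 3} to match the test above.
std::vector<float> RunQuantInt8(const std::string& model_dir,
                                const std::vector<float>& input,
                                const std::vector<int>& shape) {
  paddle::AnalysisConfig config;
  config.EnableUseGpu(100 /*initial GPU memory, MB*/, 0 /*device id*/);
  config.SetModel(model_dir);
  config.SwitchUseFeedFetchOps(false);
  // use_calib_mode=false: with Precision::kInt8 the dynamic ranges come from
  // the scale attributes recorded by the quant/dequant passes, not from a
  // generated calibration table.
  config.EnableTensorRtEngine(1 << 30 /*workspace*/, 1 /*max batch*/,
                              1 /*min subgraph size*/,
                              paddle::AnalysisConfig::Precision::kInt8,
                              false /*use_static*/, false /*use_calib_mode*/);

  auto predictor = paddle::CreatePaddlePredictor(config);

  // Feed the input through the zero-copy interface.
  auto input_t = predictor->GetInputTensor(predictor->GetInputNames()[0]);
  input_t->Reshape(shape);
  input_t->copy_from_cpu(input.data());
  predictor->ZeroCopyRun();

  // Fetch the first output and copy it back to host memory.
  auto output_t = predictor->GetOutputTensor(predictor->GetOutputNames()[0]);
  std::vector<int> out_shape = output_t->shape();
  int out_num = std::accumulate(out_shape.begin(), out_shape.end(), 1,
                                std::multiplies<int>());
  std::vector<float> out_data(out_num);
  output_t->copy_to_cpu(out_data.data());
  return out_data;
}

Note that ops under the "/skip_quant_2/" name scope are refused by the op teller, so they stay outside the TensorRT subgraph and run in the default precision, and FreezeNetwork() forces engine output layers back to FP32, which is why the test needs no output scale.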