From f7c629d91d3576f9ac575a088a7a8a7f80f6a120 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Mon, 9 Dec 2019 01:35:38 +0800 Subject: [PATCH] Revert "CHERRY_PICK: TRT int8: refine trt int8 for dynamic range set (#21112) (#21449)" (#21619) This reverts commit 0473cdb8b412d8ffaefa043e48e3578c489b4f86. --- .../ir/delete_quant_dequant_op_pass.cc | 21 -------- paddle/fluid/framework/ir/fc_fuse_pass.cc | 2 +- .../ir/quant_conv2d_dequant_fuse_pass.cc | 6 +-- .../inference/api/paddle_pass_builder.cc | 7 ++- .../inference/tensorrt/convert/conv2d_op.cc | 10 +++- .../tensorrt/convert/elementwise_op.cc | 17 +++---- .../fluid/inference/tensorrt/convert/fc_op.cc | 9 +++- .../tensorrt/convert/leaky_relu_op.cc | 7 --- .../inference/tensorrt/convert/pool2d_op.cc | 7 ++- paddle/fluid/inference/tensorrt/engine.cc | 26 +++------- paddle/fluid/inference/tensorrt/op_teller.cc | 4 -- .../fluid/inference/tests/api/CMakeLists.txt | 8 --- .../tests/api/trt_quant_int8_test.cc | 50 ------------------- 13 files changed, 36 insertions(+), 138 deletions(-) delete mode 100644 paddle/fluid/inference/tests/api/trt_quant_int8_test.cc diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc index 4dfbd5e00c..3d4df87ab7 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc @@ -39,7 +39,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(), pattern_name); pattern(); - auto* scope = param_scope(); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -48,29 +47,10 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { std::string any_op_out_name = any_op_out->Var()->Name(); std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name(); - std::string input_scale_var_name = - quant_dequant_op->Op()->Input("InScale").front(); - const LoDTensor& input_scale_tensor = - scope->FindVar(input_scale_var_name)->Get(); - - const float* input_scale_data = input_scale_tensor.data(); - float input_scale = input_scale_data[0]; auto* any_op2_desc = any_op2->Op(); // auto input_args_names = any_op2_desc->InputArgumentNames(); auto var_map = any_op2_desc->Inputs(); - std::string arg_name = ""; - for (auto& name_m : var_map) { - if (std::find(name_m.second.begin(), name_m.second.end(), - quant_dequant_op_out_name) != name_m.second.end()) { - arg_name = name_m.first; - } - } - CHECK(arg_name.size() > 0) << "can not find the input " - << quant_dequant_op_out_name; - any_op2_desc->SetAttr("enable_int8", true); - any_op2_desc->SetAttr(arg_name + "_scale", input_scale); - // modify the any_op2's inputs for (auto& name_m : var_map) { if (std::find(name_m.second.begin(), name_m.second.end(), quant_dequant_op_out_name) != name_m.second.end()) { @@ -85,7 +65,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { any_op2_desc->Flush(); } } - any_op2_desc->Flush(); // Delete the unneeded nodes. GraphSafeRemoveNodes(graph, {quant_dequant_op, quant_dequant_op_out, diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 1cb42fbe43..b53e6a250c 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -99,7 +99,7 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { auto* mul_op_desc = mul->Op(); if (mul_op_desc->HasAttr("enable_int8")) { desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8")); - desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale")); + desc.SetAttr("input_scale", mul_op_desc->GetAttr("input_scale")); desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale")); if (mul_op_desc->HasAttr("out_scale")) desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale")); diff --git a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc index d0ca452238..62fba440ed 100644 --- a/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc +++ b/paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc @@ -140,24 +140,22 @@ void RunQuantDequant(ir::Graph* graph, Scope* scope, int times, framework::OpDesc new_op_desc(base_op_desc, nullptr); new_op_desc.SetType(quantized_op_type); - new_op_desc.SetAttr("enable_int8", true); if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" || quantized_op_type == "depthwise_conv2d") { new_op_desc.SetInput("Input", {new_input}); - new_op_desc.SetAttr("Input_scale", input_scale); new_op_desc.SetOutput("Output", {new_output}); } else if (quantized_op_type == "fc") { new_op_desc.SetInput("Input", {new_input}); - new_op_desc.SetAttr("Input_scale", input_scale); new_op_desc.SetOutput("Out", {new_output}); } else if (quantized_op_type == "mul") { new_op_desc.SetInput("X", {new_input}); - new_op_desc.SetAttr("X_scale", input_scale); new_op_desc.SetOutput("Out", {new_output}); } + new_op_desc.SetAttr("enable_int8", true); + new_op_desc.SetAttr("input_scale", input_scale); new_op_desc.SetAttr("weight_scale", weight_scale); new_op_desc.Flush(); auto* new_op = graph->CreateOpNode(&new_op_desc); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 459fef40f1..959c0c00c2 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -76,10 +76,9 @@ const std::vector kTRTSubgraphPasses({ "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // "delete_quant_dequant_op_pass", // - "conv_bn_fuse_pass", // - "fc_fuse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + // "fc_fuse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 "conv_elementwise_add_act_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index fd60d844ac..73bfa800f0 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -40,8 +40,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, if (enable_int8) { #if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("Input_scale")); - float in_scale = boost::get(op_desc.GetAttr("Input_scale")); + float in_scale = boost::get(op_desc.GetAttr("input_scale")); auto weight_scale = boost::get>(op_desc.GetAttr("weight_scale")); weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, @@ -90,6 +89,13 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op, layer->getOutput(0)->setName(output_name.c_str()); engine->SetITensor(output_name, layer->getOutput(0)); +#if IS_TRT_VERSION_GE(5000) + if (enable_int8) { + float output_scale = boost::get(op_desc.GetAttr("out_scale")); + engine->SetTensorDynamicRange(layer->getOutput(0), output_scale); + } +#endif + if (test_mode) { engine->DeclareOutput(output_name); } diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 840369976d..c61dd753a3 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -110,11 +110,10 @@ class ElementwiseWeightOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name}, test_mode); - if (op_desc.HasAttr("enable_int8")) { + if (op_desc.HasAttr("out_scale")) { #if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float x_scale = boost::get(op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(X, x_scale); + float out_scale = boost::get(op_desc.GetAttr("out_scale")); + engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); #endif } } @@ -170,14 +169,10 @@ class ElementwiseTensorOpConverter : public OpConverter { layer = plugin_layer; } RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); - if (op_desc.HasAttr("enable_int8")) { + if (op_desc.HasAttr("out_scale")) { #if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - CHECK(op_desc.HasAttr("Y_scale")); - float x_scale = boost::get(op_desc.GetAttr("X_scale")); - float y_scale = boost::get(op_desc.GetAttr("Y_scale")); - engine_->SetTensorDynamicRange(X, x_scale); - engine_->SetTensorDynamicRange(Y, y_scale); + float out_scale = boost::get(op_desc.GetAttr("out_scale")); + engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); #endif } } diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index ec21bb5534..ea108d6a07 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -77,8 +77,7 @@ class FcOpConverter : public OpConverter { bool enable_int8 = boost::get(op_desc.HasAttr("enable_int8")); if (enable_int8) { #if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr(i_name + "_scale")); - float in_scale = boost::get(op_desc.GetAttr(i_name + "_scale")); + float in_scale = boost::get(op_desc.GetAttr("input_scale")); auto weight_scale = boost::get>(op_desc.GetAttr("weight_scale")); weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), @@ -136,6 +135,12 @@ class FcOpConverter : public OpConverter { auto output_name = op_desc.Output("Out").front(); RreplenishLayerAndOutput(layer, "fc", {output_name}, test_mode); + if (enable_int8) { +#if IS_TRT_VERSION_GE(5000) + float out_scale = boost::get(op_desc.GetAttr("out_scale")); + engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); +#endif + } } }; diff --git a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc index a5581737ee..f3c714009f 100644 --- a/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc @@ -42,13 +42,6 @@ class LeakyReluOpConverter : public OpConverter { engine_, Activation, *input, nvinfer1::ActivationType::kLEAKY_RELU); layer->setAlpha(alpha); output_layer = layer; - - bool enable_int8 = boost::get(op_desc.HasAttr("enable_int8")); - if (enable_int8) { - CHECK(op_desc.HasAttr("X_scale")); - float in_scale = boost::get(op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(input, in_scale); - } #else platform::CPUPlace place; std::unique_ptr alpha_tensor( diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 846e2154b1..09659af7af 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -160,11 +160,10 @@ class Pool2dOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode); - if (op_desc.HasAttr("enable_int8")) { + if (op_desc.HasAttr("out_scale")) { #if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float input_scale = boost::get(op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(input1, input_scale); + float out_scale = boost::get(op_desc.GetAttr("out_scale")); + engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); #endif } } diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 6f66e8d972..85722c94b2 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -104,31 +104,12 @@ void TensorRTEngine::FreezeNetwork() { for (auto &t : all_t) { if (!quant_dynamic_range_.count(t)) { - VLOG(3) + LOG(WARNING) << "We are in trt int8 mode(not calibration), scale not setted" << " for tensor " << t->getName() << ", this might be ok when trt does not need this range"; } } - std::unordered_set all_out_t_name; - for (int i = 0; i < infer_network_->getNbOutputs(); i++) { - auto *temp = infer_network_->getOutput(i); - temp->setDynamicRange(-1, 1); - all_out_t_name.insert(temp->getName()); - } - - for (int i = 0; i < infer_network_->getNbLayers(); i++) { - auto layer = infer_network_->getLayer(i); - for (int j = 0; j < layer->getNbOutputs(); j++) { - auto *temp_out = layer->getOutput(j); - if (std::find(all_out_t_name.begin(), all_out_t_name.end(), - temp_out->getName()) != all_out_t_name.end()) { - layer->setPrecision(nvinfer1::DataType::kFLOAT); - layer->setOutputType(j, nvinfer1::DataType::kFLOAT); - } - } - } - #endif } } @@ -234,6 +215,11 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, (scale.size() == 1 || scale.size() == static_cast(w_dims[0])); PADDLE_ENFORCE(valid_scale_size, "TRT int8 quant: invalid scale size"); for (int i = 0; i < weight_tensor->numel(); i++) { + bool is_valid_int8 = + ((weight_data[i] >= -128) && (weight_data[i] <= 127)); + PADDLE_ENFORCE(is_valid_int8, + "We are in anakin subgraph int8 mode, the weight of conv " + "should be in range [-128, 127]"); if (scale.size() == 1) { weight_data[i] *= (scale[0] / 127); } else { diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index aa9f330bc7..fce8daa1d6 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -56,10 +56,6 @@ struct SimpleOpTypeSetTeller : public Teller { }; bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { - // do not support the op which is labeled the `skip_quant` - if (desc.HasAttr("op_namescope") && - boost::get(desc.GetAttr("op_namescope")) == "/skip_quant_2/") - return false; for (auto& teller : tellers_) { if (op_type == "pool2d" || op_type == "conv2d" || op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") { diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 5952c1c4fc..f8a61b46de 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -328,14 +328,6 @@ if(WITH_GPU AND TENSORRT_FOUND) inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) - - set(TRT_MODEL_QUANT_RESNET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant_small_model") - if (NOT EXISTS ${TRT_MODEL_QUANT_RESNET_DIR}) - inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "quant_small_model.tar.gz") - endif() - inference_analysis_test(trt_quant_int8_test SRCS trt_quant_int8_test.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR}) endif() inference_analysis_test(test_analyzer_capi SRCS analyzer_capi_tester.cc diff --git a/paddle/fluid/inference/tests/api/trt_quant_int8_test.cc b/paddle/fluid/inference/tests/api/trt_quant_int8_test.cc deleted file mode 100644 index e1ce9d5c20..0000000000 --- a/paddle/fluid/inference/tests/api/trt_quant_int8_test.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/inference/tests/api/trt_test_helper.h" - -namespace paddle { -namespace inference { - -TEST(quant_int8, resnet50) { - std::string model_dir = FLAGS_infer_model; - AnalysisConfig config; - config.EnableUseGpu(100, 0); - config.SetModel(model_dir); - config.SwitchUseFeedFetchOps(false); - config.EnableTensorRtEngine(1 << 30, 1, 1, AnalysisConfig::Precision::kInt8, - false, false); - - auto predictor = CreatePaddlePredictor(config); - auto input_names = predictor->GetInputNames(); - int channels = 1; - int height = 3; - int width = 3; - int input_num = channels * height * width * 1; - - float *input = new float[input_num]; - memset(input, 0, input_num * sizeof(float)); - auto input_t = predictor->GetInputTensor(input_names[0]); - input_t->Reshape({1, channels, height, width}); - input_t->copy_from_cpu(input); - - ASSERT_TRUE(predictor->ZeroCopyRun()); -} - -} // namespace inference -} // namespace paddle -- GitLab