From 2953b708a03d023b6b6b1fecde7ac431f8f48a94 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Mon, 31 Oct 2022 14:23:40 +0800 Subject: [PATCH] feat: add int8 support for vit (#47330) * feat: add int8 support for vit * test:add test --- .../framework/ir/vit_attention_fuse_pass.cc | 26 +++++++++++++ .../tensorrt/convert/multihead_matmul_op.cc | 38 +++++++++++++++---- .../test_trt_convert_multihead_matmul.py | 12 +++++- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc index 4f61b03010..3ff91e0bcb 100644 --- a/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc +++ b/paddle/fluid/framework/ir/vit_attention_fuse_pass.cc @@ -56,6 +56,22 @@ namespace paddle { namespace framework { namespace ir { +bool HasScale(OpDesc* const op_ptr, + std::string* name, + std::string regexp = "Input_scale_") { + name->clear(); + std::unordered_map attr_map = op_ptr->GetAttrMap(); + std::unordered_map::iterator iter; + int len = regexp.size(); + for (iter = attr_map.begin(); iter != attr_map.end(); iter++) { + if (regexp == iter->first.substr(0, len)) { + *name = iter->first; + return true; + } + } + return false; +} + void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const { GraphPatternDetector gpd; const std::string pattern_name = "vit_attention_fuse"; @@ -103,6 +119,16 @@ void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const { float alpha = PADDLE_GET_CONST(float, scale1_op->Op()->GetAttr("scale")); desc.SetAttr("alpha", alpha); + // int8 for fc + std::string scale_name; + if (HasScale(matmul0_op->Op(), &scale_name)) { + desc.SetAttr("Input_scale", matmul0_op->Op()->GetAttr(scale_name)); + } + if (HasScale(elementwise0_op->Op(), &scale_name, "Out")) { + desc.SetAttr("fc_out_threshold", + elementwise0_op->Op()->GetAttr(scale_name)); + } + // Create a new node for the fused op. auto vit_attention_node = graph->CreateOpNode(&desc); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index f997c8bd1f..0515cb513d 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -398,13 +398,37 @@ class MultiheadMatMulOpConverter : public OpConverter { // add fc layer nvinfer1::ILayer* fc_layer = nullptr; - fc_layer = - TRT_ENGINE_ADD_LAYER(engine_, - FullyConnected, - *reshape_before_fc_layer->getOutput(0), - n, - weight, - bias); + if (op_desc.HasAttr("Input_scale")) { + engine_->SetTensorDynamicRange( + reshape_before_fc_layer->getOutput(0), in_scale); + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *reshape_before_fc_layer->getOutput(0), + n, + nv_ksize, + weight, + bias); + PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), + true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers " + "in int8 mode")); + float out_scale = + PADDLE_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + } else { + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight, + bias); + } + fc_layer->setName( + ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); // add shuffle for CustomQKVToContextPluginDynamic layer auto* reshape_after_fc_layer = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py index 6889f88fa4..fa1cb51e7f 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py @@ -818,7 +818,11 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): "Y": ["matmul1_weight"], }, "op_outputs": {"Out": ["matmul1_output"]}, - "op_attrs": {"trans_x": False, "trans_y": False}, + "op_attrs": { + "trans_x": False, + "trans_y": False, + "Input_scale_layer": 1.0, + }, }, { "op_type": "elementwise_add", @@ -832,6 +836,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): "Scale_x": 1.0, "Scale_y": 1.0, "axis": 2, + "Out": 1.0, }, }, { @@ -1035,6 +1040,11 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.workspace_size = 2013265920 + self.trt_param.precision = paddle_infer.PrecisionType.Int8 + yield self.create_inference_config(), generate_trt_nodes_num(), ( + 1e-3, + 1e-3, + ) self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num(), ( 1e-3, -- GitLab