未验证 提交 2953b708 编写于 作者: F feng_shuai 提交者: GitHub

feat: add int8 support for vit (#47330)

* feat: add int8 support for vit

* test:add test
上级 34d13d6a
...@@ -56,6 +56,22 @@ namespace paddle { ...@@ -56,6 +56,22 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
bool HasScale(OpDesc* const op_ptr,
std::string* name,
std::string regexp = "Input_scale_") {
name->clear();
std::unordered_map<std::string, Attribute> attr_map = op_ptr->GetAttrMap();
std::unordered_map<std::string, Attribute>::iterator iter;
int len = regexp.size();
for (iter = attr_map.begin(); iter != attr_map.end(); iter++) {
if (regexp == iter->first.substr(0, len)) {
*name = iter->first;
return true;
}
}
return false;
}
void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const { void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
const std::string pattern_name = "vit_attention_fuse"; const std::string pattern_name = "vit_attention_fuse";
...@@ -103,6 +119,16 @@ void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -103,6 +119,16 @@ void VitAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
float alpha = PADDLE_GET_CONST(float, scale1_op->Op()->GetAttr("scale")); float alpha = PADDLE_GET_CONST(float, scale1_op->Op()->GetAttr("scale"));
desc.SetAttr("alpha", alpha); desc.SetAttr("alpha", alpha);
// int8 for fc
std::string scale_name;
if (HasScale(matmul0_op->Op(), &scale_name)) {
desc.SetAttr("Input_scale", matmul0_op->Op()->GetAttr(scale_name));
}
if (HasScale(elementwise0_op->Op(), &scale_name, "Out")) {
desc.SetAttr("fc_out_threshold",
elementwise0_op->Op()->GetAttr(scale_name));
}
// Create a new node for the fused op. // Create a new node for the fused op.
auto vit_attention_node = graph->CreateOpNode(&desc); auto vit_attention_node = graph->CreateOpNode(&desc);
......
...@@ -398,13 +398,37 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -398,13 +398,37 @@ class MultiheadMatMulOpConverter : public OpConverter {
// add fc layer // add fc layer
nvinfer1::ILayer* fc_layer = nullptr; nvinfer1::ILayer* fc_layer = nullptr;
fc_layer = if (op_desc.HasAttr("Input_scale")) {
TRT_ENGINE_ADD_LAYER(engine_, engine_->SetTensorDynamicRange(
FullyConnected, reshape_before_fc_layer->getOutput(0), in_scale);
*reshape_before_fc_layer->getOutput(0), nvinfer1::DimsHW nv_ksize(1, 1);
n, fc_layer =
weight, TRT_ENGINE_ADD_LAYER(engine_,
bias); Convolution,
*reshape_before_fc_layer->getOutput(0),
n,
nv_ksize,
weight,
bias);
PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"),
true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers "
"in int8 mode"));
float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
} else {
fc_layer =
TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_layer->getOutput(0),
n,
weight,
bias);
}
fc_layer->setName(
("multihead_mamul_fc(Output: " + output_name + ")").c_str());
// add shuffle for CustomQKVToContextPluginDynamic layer // add shuffle for CustomQKVToContextPluginDynamic layer
auto* reshape_after_fc_layer = auto* reshape_after_fc_layer =
......
...@@ -818,7 +818,11 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -818,7 +818,11 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
"Y": ["matmul1_weight"], "Y": ["matmul1_weight"],
}, },
"op_outputs": {"Out": ["matmul1_output"]}, "op_outputs": {"Out": ["matmul1_output"]},
"op_attrs": {"trans_x": False, "trans_y": False}, "op_attrs": {
"trans_x": False,
"trans_y": False,
"Input_scale_layer": 1.0,
},
}, },
{ {
"op_type": "elementwise_add", "op_type": "elementwise_add",
...@@ -832,6 +836,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -832,6 +836,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
"Scale_x": 1.0, "Scale_x": 1.0,
"Scale_y": 1.0, "Scale_y": 1.0,
"axis": 2, "axis": 2,
"Out": 1.0,
}, },
}, },
{ {
...@@ -1035,6 +1040,11 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -1035,6 +1040,11 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
# for dynamic_shape # for dynamic_shape
generate_dynamic_shape(attrs) generate_dynamic_shape(attrs)
self.trt_param.workspace_size = 2013265920 self.trt_param.workspace_size = 2013265920
self.trt_param.precision = paddle_infer.PrecisionType.Int8
yield self.create_inference_config(), generate_trt_nodes_num(), (
1e-3,
1e-3,
)
self.trt_param.precision = paddle_infer.PrecisionType.Half self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(), ( yield self.create_inference_config(), generate_trt_nodes_num(), (
1e-3, 1e-3,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册