diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index 5a97727da3b456981d5fbef8fda053695c3bfc27..5c23e826a2decffdaf9c138cd9bd50e098aa9286 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -758,7 +758,9 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, - Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) { + Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out, + Node* softmax_qk, Node* eltadd0, Node* eltadd1, Node* eltadd2, + Node* matmul_qk) { auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale")); // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H) @@ -876,19 +878,35 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, weight_max = std::max(weight_max, weight_scale2); multihead_op_desc.SetAttr("weight_scale", weight_max); - if (mul0_op_desc->HasAttr("out_threshold")) { + auto* add0_op_desc = eltadd0->Op(); + auto* add1_op_desc = eltadd1->Op(); + auto* add2_op_desc = eltadd2->Op(); + if (add0_op_desc->HasAttr("out_threshold")) { auto out_scale0 = - BOOST_GET_CONST(float, mul0_op_desc->GetAttr("out_threshold")); + BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold")); auto out_scale1 = - BOOST_GET_CONST(float, mul1_op_desc->GetAttr("out_threshold")); + BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold")); auto out_scale2 = - BOOST_GET_CONST(float, mul2_op_desc->GetAttr("out_threshold")); + BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold")); auto out_scale_max = std::max(out_scale0, out_scale1); out_scale_max = std::max(out_scale_max, out_scale2); - multihead_op_desc.SetAttr("out_threshold", out_scale_max); + multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max); } } + auto* softmax_qk_op_desc = softmax_qk->Op(); + auto* matmul_qk_op_desc = matmul_qk->Op(); + if (matmul_qk_op_desc->HasAttr("X_scale")) { + multihead_op_desc.SetAttr("qkv2context_plugin_int8", true); + if (softmax_qk_op_desc->HasAttr("out_threshold")) { + auto qkv_plugin_scale = BOOST_GET_CONST( + float, softmax_qk_op_desc->GetAttr("out_threshold")); + multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale); + } + } else { + multihead_op_desc.SetAttr("qkv2context_plugin_int8", false); + } + auto* multihead = graph->CreateOpNode(&multihead_op_desc); IR_NODE_LINK_TO(input0, multihead); @@ -990,7 +1008,8 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, } fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, - reshape2_0, reshape2_qkv_out, scale, scale_out); + reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk, + eltadd0, eltadd1, eltadd2, matmul_qk); std::unordered_set marked_nodes({eltadd0, eltadd1, diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index ef50f3db42c6f54b146e01cb8124d344746af212..666a1a914666cf5fd904ad9b79b1ffa212702526 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -164,10 +164,9 @@ class FcOpConverter : public OpConverter { auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, nv_ksize, weight.get(), bias.get()); + engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); auto* fc_after_reshape_int8 = reshape_after_fc( fc_layer_int8->getOutput(0), x_dim, x_num_col_dims); - engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0), - out_scale); if (activation_type == "relu") { nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( engine_, Activation, *(fc_after_reshape_int8->getOutput(0)), diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index d05c9019a29d3980c701a55629b1deb04a1ddb0b..2a9b015ce982cb0f6982058a4e9280f951f45d62 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -42,6 +42,8 @@ class MultiheadMatMulOpConverter : public OpConverter { float* weight_data = nullptr; bool enable_int8 = op_desc.HasAttr("enable_int8"); + bool qkv2context_plugin_int8 = + BOOST_GET_CONST(bool, op_desc.GetAttr("qkv2context_plugin_int8")); float in_scale = 0.; if (enable_int8) { @@ -147,13 +149,16 @@ class MultiheadMatMulOpConverter : public OpConverter { if (enable_int8) { PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, + op_desc.HasAttr("fc_out_threshold"), true, platform::errors::InvalidArgument( "must have out threshold in multihead layers in int8 mode")); float out_scale = - BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); - dp_probs = out_scale / 127.0; + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } } auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); @@ -166,16 +171,25 @@ class MultiheadMatMulOpConverter : public OpConverter { : nvinfer1::DataType::kFLOAT); if (enable_int8) { type = static_cast(nvinfer1::DataType::kHALF); + if (qkv2context_plugin_int8) { + type = static_cast(nvinfer1::DataType::kINT8); + } } bool has_mask = true; int var_seqlen = 1; - const std::vector fields{ + std::vector fields{ {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1}, {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, - {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}, - { "dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1 }}; + { "var_seqlen", + &var_seqlen, + nvinfer1::PluginFieldType::kINT32, + 1 }}; + if (qkv2context_plugin_int8) { + fields.push_back( + {"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1}); + } nvinfer1::PluginFieldCollection* plugin_collection = static_cast( malloc(sizeof(*plugin_collection) +