未验证 提交 0043fa8c 编写于 作者: C ceci3 提交者: GitHub

[paddle-TRT]support matmul set to int8 in multihead (#34917)

* update ernie int8
上级 c0bdef5d
...@@ -758,7 +758,9 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, ...@@ -758,7 +758,9 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) { Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out,
Node* softmax_qk, Node* eltadd0, Node* eltadd1, Node* eltadd2,
Node* matmul_qk) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale")); auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H) // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
...@@ -876,19 +878,35 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, ...@@ -876,19 +878,35 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
weight_max = std::max(weight_max, weight_scale2); weight_max = std::max(weight_max, weight_scale2);
multihead_op_desc.SetAttr("weight_scale", weight_max); multihead_op_desc.SetAttr("weight_scale", weight_max);
if (mul0_op_desc->HasAttr("out_threshold")) { auto* add0_op_desc = eltadd0->Op();
auto* add1_op_desc = eltadd1->Op();
auto* add2_op_desc = eltadd2->Op();
if (add0_op_desc->HasAttr("out_threshold")) {
auto out_scale0 = auto out_scale0 =
BOOST_GET_CONST(float, mul0_op_desc->GetAttr("out_threshold")); BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
auto out_scale1 = auto out_scale1 =
BOOST_GET_CONST(float, mul1_op_desc->GetAttr("out_threshold")); BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
auto out_scale2 = auto out_scale2 =
BOOST_GET_CONST(float, mul2_op_desc->GetAttr("out_threshold")); BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
auto out_scale_max = std::max(out_scale0, out_scale1); auto out_scale_max = std::max(out_scale0, out_scale1);
out_scale_max = std::max(out_scale_max, out_scale2); out_scale_max = std::max(out_scale_max, out_scale2);
multihead_op_desc.SetAttr("out_threshold", out_scale_max); multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
} }
} }
auto* softmax_qk_op_desc = softmax_qk->Op();
auto* matmul_qk_op_desc = matmul_qk->Op();
if (matmul_qk_op_desc->HasAttr("X_scale")) {
multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
if (softmax_qk_op_desc->HasAttr("out_threshold")) {
auto qkv_plugin_scale = BOOST_GET_CONST(
float, softmax_qk_op_desc->GetAttr("out_threshold"));
multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
}
} else {
multihead_op_desc.SetAttr("qkv2context_plugin_int8", false);
}
auto* multihead = graph->CreateOpNode(&multihead_op_desc); auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(input0, multihead); IR_NODE_LINK_TO(input0, multihead);
...@@ -990,7 +1008,8 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, ...@@ -990,7 +1008,8 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
} }
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out); reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk,
eltadd0, eltadd1, eltadd2, matmul_qk);
std::unordered_set<const Node*> marked_nodes({eltadd0, std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1, eltadd1,
......
...@@ -164,10 +164,9 @@ class FcOpConverter : public OpConverter { ...@@ -164,10 +164,9 @@ class FcOpConverter : public OpConverter {
auto* fc_layer_int8 = auto* fc_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
nv_ksize, weight.get(), bias.get()); nv_ksize, weight.get(), bias.get());
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale);
auto* fc_after_reshape_int8 = reshape_after_fc( auto* fc_after_reshape_int8 = reshape_after_fc(
fc_layer_int8->getOutput(0), x_dim, x_num_col_dims); fc_layer_int8->getOutput(0), x_dim, x_num_col_dims);
engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0),
out_scale);
if (activation_type == "relu") { if (activation_type == "relu") {
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_after_reshape_int8->getOutput(0)), engine_, Activation, *(fc_after_reshape_int8->getOutput(0)),
......
...@@ -42,6 +42,8 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -42,6 +42,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
float* weight_data = nullptr; float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
bool qkv2context_plugin_int8 =
BOOST_GET_CONST(bool, op_desc.GetAttr("qkv2context_plugin_int8"));
float in_scale = 0.; float in_scale = 0.;
if (enable_int8) { if (enable_int8) {
...@@ -147,13 +149,16 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -147,13 +149,16 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (enable_int8) { if (enable_int8) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true, op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"must have out threshold in multihead layers in int8 mode")); "must have out threshold in multihead layers in int8 mode"));
float out_scale = float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
dp_probs = out_scale / 127.0; if (qkv2context_plugin_int8) {
dp_probs =
BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
}
} }
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
...@@ -166,16 +171,25 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -166,16 +171,25 @@ class MultiheadMatMulOpConverter : public OpConverter {
: nvinfer1::DataType::kFLOAT); : nvinfer1::DataType::kFLOAT);
if (enable_int8) { if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kHALF); type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
} }
bool has_mask = true; bool has_mask = true;
int var_seqlen = 1; int var_seqlen = 1;
const std::vector<nvinfer1::PluginField> fields{ std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1}, {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}, { "var_seqlen",
{ "dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1 }}; &var_seqlen,
nvinfer1::PluginFieldType::kINT32,
1 }};
if (qkv2context_plugin_int8) {
fields.push_back(
{"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1});
}
nvinfer1::PluginFieldCollection* plugin_collection = nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>( static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) + malloc(sizeof(*plugin_collection) +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册