Unverified commit 6da043eb, authored by ceci3, committed by GitHub

support ernie trt-int8 for inference (#32232)

* support ernie trt-int8 for inference

* fix reshape
Parent fabdb43c
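
This commit wires Paddle's quantization metadata through to TensorRT int8 engines for ERNIE: three IR fuse passes start tagging their fused ops with `enable_int8`, six TensorRT op converters consume the recorded scales, the op teller allow-lists the fused ops for int8, and the Python quantization pass learns to record output scales for `layer_norm` and `slice`. The recurring pattern on the converter side, using only calls that appear in the hunks below, is:

```cpp
// Recurring converter-side pattern in this commit: the quantization pass has
// stored an abs-max threshold as the "out_threshold" attribute, and the
// converter registers it with the engine as the tensor's int8 dynamic range.
if (op_desc.HasAttr("out_threshold")) {
  float out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
  engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
}
```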
@@ -299,6 +299,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope
       new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
       new_op_desc.SetAttr("epsilon",
                           end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
+
+      if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
+        new_op_desc.SetAttr("enable_int8", true);
+      }
+
       auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
       for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
......
@@ -535,6 +535,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
     multihead_op_desc.SetAttr("alpha", scale_attr);
     multihead_op_desc.SetAttr("head_number", head_number);
+
+    auto* mul0_op_desc = mul0->Op();
+    auto* mul1_op_desc = mul1->Op();
+    auto* mul2_op_desc = mul2->Op();
+    if (mul0_op_desc->HasAttr("enable_int8")) {
+      multihead_op_desc.SetAttr("enable_int8",
+                                mul0_op_desc->GetAttr("enable_int8"));
+      // all mul op has same input.
+      multihead_op_desc.SetAttr("Input_scale",
+                                mul0_op_desc->GetAttr("X_scale"));
+      auto weight_scale0 = BOOST_GET_CONST(
+          std::vector<float>, mul0_op_desc->GetAttr("weight_scale"));
+      auto weight_scale1 = BOOST_GET_CONST(
+          std::vector<float>, mul1_op_desc->GetAttr("weight_scale"));
+      auto weight_scale2 = BOOST_GET_CONST(
+          std::vector<float>, mul2_op_desc->GetAttr("weight_scale"));
+      auto weight_max = std::max(weight_scale0, weight_scale1);
+      weight_max = std::max(weight_max, weight_scale2);
+      multihead_op_desc.SetAttr("weight_scale", weight_max);
+      if (mul0_op_desc->HasAttr("out_threshold")) {
+        auto out_scale0 =
+            BOOST_GET_CONST(float, mul0_op_desc->GetAttr("out_threshold"));
+        auto out_scale1 =
+            BOOST_GET_CONST(float, mul1_op_desc->GetAttr("out_threshold"));
+        auto out_scale2 =
+            BOOST_GET_CONST(float, mul2_op_desc->GetAttr("out_threshold"));
+        auto out_scale_max = std::max(out_scale0, out_scale1);
+        out_scale_max = std::max(out_scale_max, out_scale2);
+        multihead_op_desc.SetAttr("out_threshold", out_scale_max);
+      }
+    }
+
     auto* multihead = graph->CreateOpNode(&multihead_op_desc);
     IR_NODE_LINK_TO(input0, multihead);
......
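Because the three Q/K/V `mul` ops share one input tensor (as the `// all mul op has same input.` comment notes), a single `Input_scale` suffices, but each `mul` carries its own `weight_scale` vector and `out_threshold`, and the fused op can keep only one of each. For the scalar `out_threshold` the `std::max` chain is straightforward; for the `weight_scale` vectors, note that `std::max` on `std::vector<float>` compares lexicographically and returns one whole vector. A per-channel element-wise maximum, if that were the intent, would look like this hypothetical sketch (not what the hunk does):

```cpp
// Hypothetical element-wise variant, assuming the three Q/K/V weight-scale
// vectors have equal length; the hunk above instead selects one whole vector
// by lexicographic comparison.
std::vector<float> weight_max(weight_scale0.size());
for (size_t i = 0; i < weight_max.size(); ++i) {
  weight_max[i] =
      std::max({weight_scale0[i], weight_scale1[i], weight_scale2[i]});
}
```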
@@ -153,6 +153,10 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
   new_desc.SetInput("Scale", {layer_norm_scale->Name()});
   new_desc.SetInput("Bias", {layer_norm_bias->Name()});
+
+  if (elementwise->Op()->HasAttr("out_threshold")) {
+    new_desc.SetAttr("enable_int8", true);
+  }
   // outputs
   new_desc.SetOutput("Out", {layer_norm_out->Name()});
......
@@ -31,7 +31,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
 #if IS_TRT_VERSION_GE(6000)
-    VLOG(4) << "convert fluid swish op to tensorrt layer";
+    VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
     framework::OpDesc op_desc(op, nullptr);
     auto id_names = op_desc.Input("Ids");
......
@@ -89,10 +89,14 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
     int64_t bias_size = framework::product(bias_dims);
     int64_t scale_size = framework::product(scale_dims);
     nvinfer1::ILayer* layer = nullptr;
+    bool enable_int8 = op_desc.HasAttr("enable_int8");
     if (engine_->with_dynamic_shape()) {
       if (engine_->use_oss()) {
         int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
+        if (enable_int8) {
+          output_fp16 = 1;
+        }
         PADDLE_ENFORCE_EQ(
             output_fp16, 1,
             platform::errors::InvalidArgument(
......
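In the oss (variable-seqlen) path, `enable_int8` forces `output_fp16 = 1`, so the `PADDLE_ENFORCE_EQ(output_fp16, 1, ...)` check that follows can no longer reject an int8 build: the int8 variant of this plugin always emits half-precision output. Condensed, the flag logic above is:

```cpp
// Condensed form of the hunk above: int8 mode implies fp16 output for the
// oss embedding-eltwise-layernorm plugin.
int output_fp16 = (engine_->WithFp16() == 1 || enable_int8) ? 1 : 0;
```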
@@ -106,8 +106,22 @@ class FcOpConverter : public OpConverter {
     auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output,
                          TensorRTEngine::Weight& weight,
                          TensorRTEngine::Weight& bias) {
-      auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs,
-                                            n_output, weight.get(), bias.get());
+      nvinfer1::ILayer* fc_layer = nullptr;
+      if (enable_int8) {
+        PADDLE_ENFORCE_EQ(
+            op_desc.HasAttr("out_threshold"), true,
+            platform::errors::InvalidArgument(
+                "must have out threshold in fc layers in int8 mode"));
+        float out_scale =
+            BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+        nvinfer1::DimsHW nv_ksize(1, 1);
+        fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
+                                        nv_ksize, weight.get(), bias.get());
+        engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
+      } else {
+        fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs,
+                                        n_output, weight.get(), bias.get());
+      }
       auto output_name = op_desc.Output("Out").front();
       if (activation_type == "relu") {
@@ -229,13 +243,24 @@ class FcOpConverter : public OpConverter {
             "dims equals to 4, the last dim of input must be 1, but got %d",
             input_d[3]));
       }
-      for (int i = 0; i < 3; i++) {
-        if (i < input_dims) {
-          reshape_dim3[i] = input_d[i];
-        } else {
-          reshape_dim3[i] = 1;
+
+      if (enable_int8) {
+        reshape_dim3[0] = 1;
+        for (int i = 0; i < 3; i++) {
+          reshape_dim3[0] *= input_d[i];
+          if (i > 0) {
+            reshape_dim3[i] = 1;
+          }
+        }
+      } else {
+        for (int i = 0; i < 3; i++) {
+          if (i < input_dims) {
+            reshape_dim3[i] = input_d[i];
+          } else {
+            reshape_dim3[i] = 1;
+          }
         }
       }
       nvinfer1::Dims3 reshape_dim(reshape_dim3[0], reshape_dim3[1],
                                   reshape_dim3[2]);
       auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
@@ -249,11 +274,25 @@ class FcOpConverter : public OpConverter {
           platform::errors::InvalidArgument(
               "Invalid dimensions. When x_num_col_dims equals to "
               "2, input_dims should not be 1"));
-      for (int i = 0; i < 4; i++) {
-        if (i < input_dims) {
-          reshape_dim4[i] = input_d[i];
-        } else {
-          reshape_dim4[i] = 1;
+
+      if (enable_int8) {
+        for (int i = 0; i < 4; i++) {
+          if (i == 0) {
+            reshape_dim4[i] = input_d[i];
+          } else {
+            reshape_dim4[i] = 1;
+            if (i < input_dims) {
+              reshape_dim4[1] *= input_d[i];
+            }
+          }
+        }
+      } else {
+        for (int i = 0; i < 4; i++) {
+          if (i < input_dims) {
+            reshape_dim4[i] = input_d[i];
+          } else {
+            reshape_dim4[i] = 1;
+          }
         }
       }
       nvinfer1::Dims4 reshape_dim(reshape_dim4[0], reshape_dim4[1],
......
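The reshape fork exists because the int8 branch of `regist_fc` (first hunk in this file) builds the FC as a 1x1 `Convolution` instead of a `FullyConnected` layer, and a 1x1 convolution reduces over the channel axis, so the feature dimensions must be folded into one dimension rather than each padded with 1. A condensed restatement of the 4-dim int8 branch, assuming `input_d` holds `input_dims` valid entries with `input_dims <= 4`:

```cpp
// int8 path: keep the leading dim, fold the remaining input dims into the
// channel dim, and pad with 1s -> {d0, d1*d2*..., 1, 1}. The fp32/fp16 path
// instead keeps each dim and pads the tail -> {d0, d1, d2, 1}.
int reshape_dim4[4] = {input_d[0], 1, 1, 1};
for (int i = 1; i < input_dims; i++) {
  reshape_dim4[1] *= input_d[i];
}
```

The 3-dim case does the same thing one rank lower, folding everything into dim 0.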
@@ -40,8 +40,25 @@ class MultiheadMatMulOpConverter : public OpConverter {
     auto* bias_v = scope.FindVar(bias_name);
     auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
-    float* weight_data =
-        engine_->GetWeightCPUData(weight_name, weight_t, false);
+    float* weight_data = nullptr;
+    bool enable_int8 = op_desc.HasAttr("enable_int8");
+    float in_scale = 0.;
+    if (enable_int8) {
+      PADDLE_ENFORCE_EQ(
+          op_desc.HasAttr("Input_scale"), true,
+          platform::errors::InvalidArgument(
+              "must have input scale in multihead layers in int8 mode"));
+      in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
+      auto weight_scale =
+          BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
+      weight_data =
+          engine_->GetWeightCPUData(weight_name, weight_t, true, weight_scale);
+      engine_->SetTensorDynamicRange(input, in_scale);
+    } else {
+      weight_data = engine_->GetWeightCPUData(weight_name, weight_t, false);
+    }
+
     float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false);
     std::vector<float> weight_data_tmp;
     weight_data_tmp.reserve(weight_t->numel());
......
@@ -117,8 +134,27 @@ class MultiheadMatMulOpConverter : public OpConverter {
                                  static_cast<void*>(bias_data),
                                  static_cast<int32_t>(bias_t->numel())};
-        auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input,
-                                              n, weight, bias);
+        nvinfer1::ILayer* fc_layer = nullptr;
+        float dp_probs = 1.0 / 127.0;
+        if (enable_int8) {
+          nvinfer1::DimsHW nv_ksize(1, 1);
+          fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n,
+                                          nv_ksize, weight, bias);
+        } else {
+          fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
+                                          weight, bias);
+        }
+
+        if (enable_int8) {
+          PADDLE_ENFORCE_EQ(
+              op_desc.HasAttr("out_threshold"), true,
+              platform::errors::InvalidArgument(
+                  "must have out threshold in multihead layers in int8 mode"));
+          float out_scale =
+              BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+          engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
+          dp_probs = out_scale / 127.0;
+        }
+
         auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
@@ -128,6 +164,9 @@ class MultiheadMatMulOpConverter : public OpConverter {
         int type = static_cast<int>((engine_->WithFp16() == 1)
                                         ? nvinfer1::DataType::kHALF
                                         : nvinfer1::DataType::kFLOAT);
+        if (enable_int8) {
+          type = static_cast<int>(nvinfer1::DataType::kHALF);
+        }
         bool has_mask = true;
         int var_seqlen = 1;
         const std::vector<nvinfer1::PluginField> fields{
@@ -136,7 +175,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
             {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
             {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
             {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
-        };
+            {"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1}};
         nvinfer1::PluginFieldCollection* plugin_collection =
             static_cast<nvinfer1::PluginFieldCollection*>(
                 malloc(sizeof(*plugin_collection) +
......
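Two scale conventions meet in this converter: the recorded `Input_scale` is multiplied by 127 before being registered as the input's dynamic range, while the plugin receives `dq_probs = out_threshold / 127.0` (defaulting to `1.0 / 127.0`; note the plugin field is named "dq_probs" while the local variable is `dp_probs`). Assuming `TensorRTEngine::SetTensorDynamicRange(t, s)` ultimately forwards the threshold to TensorRT as a symmetric range, registering a scale amounts to:

```cpp
// Assumption about what SetTensorDynamicRange(tensor, out_scale) forwards to:
// TensorRT's symmetric per-tensor range, mapping int8 value 127 to out_scale.
tensor->setDynamicRange(-out_scale, out_scale);
```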
@@ -49,6 +49,7 @@ class SkipLayerNormOpConverter : public OpConverter {
     auto* scale = get_persistable_data("Scale", &scale_dims);
     int bias_size = framework::product(bias_dims);
     int scale_size = framework::product(scale_dims);
+    bool enable_int8 = op_desc.HasAttr("enable_int8");
     nvinfer1::ILayer* layer = nullptr;
     if (engine_->with_dynamic_shape()) {
@@ -62,6 +63,10 @@ class SkipLayerNormOpConverter : public OpConverter {
         int ld = input1->getDimensions().d[2];  // hidden dimension
         assert(ld > 0);
+        if (enable_int8) {
+          type = static_cast<int>(nvinfer1::DataType::kHALF);
+        }
+
         const std::vector<nvinfer1::PluginField> fields{
             {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
             {"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1},
......
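Both this converter and the multihead one above pin the plugin's `type_id` to `kHALF` whenever `enable_int8` is set, consistent with the embedding converter forcing fp16 output: as far as this diff shows, the int8 path drives the oss plugins with half-precision type ids and expresses int8 through the registered dynamic ranges. Condensed, the selection in both places is:

```cpp
// Condensed form of the type_id selection used by the multihead and
// skip_layernorm converters in this commit: enable_int8 wins, and both the
// int8 and fp16 cases map to kHALF.
int type = static_cast<int>((enable_int8 || engine_->WithFp16() == 1)
                                ? nvinfer1::DataType::kHALF
                                : nvinfer1::DataType::kFLOAT);
```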
@@ -31,6 +31,12 @@ class SliceOpConverter : public OpConverter {
     // Declare inputs
     auto* input = engine_->GetITensor(op_desc.Input("Input")[0]);
+
+    if (op_desc.HasAttr("out_threshold")) {
+      float out_scale =
+          BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+      engine_->SetTensorDynamicRange(input, out_scale);
+    }
     std::vector<int> axes =
         BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
     std::vector<int> starts =
......
@@ -45,6 +45,11 @@ class StackOpConverter : public OpConverter {
   for (int i = 0; i < input_num; ++i) {
     inputs[i] = engine_->GetITensor(input[i]);
+    if (op_desc.HasAttr("out_threshold")) {
+      float out_scale =
+          BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+      engine_->SetTensorDynamicRange(inputs[i], out_scale);
+    }
   }
   int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
......
@@ -45,6 +45,12 @@ struct SimpleOpTypeSetTeller : public Teller {
 #endif
 #if IS_TRT_VERSION_GE(7130)
     teller_set.insert("group_norm");
+    int8_teller_set.insert("multihead_matmul");
+    int8_teller_set.insert("skip_layernorm");
+    int8_teller_set.insert("fused_embedding_eltwise_layernorm");
+    int8_teller_set.insert("matmul");
+    int8_teller_set.insert("stack");
+    int8_teller_set.insert("slice");
 #endif
   }
......
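`int8_teller_set` is the allow-list consulted when the engine is built with int8 precision, so these six ops only become int8-eligible on TensorRT >= 7.1.3. `IS_TRT_VERSION_GE(7130)` is a compile-time gate assembled from the TensorRT version macros; a sketch of the usual form of such a gate (an assumption about the exact definition, not quoted from this diff):

```cpp
// Sketch (assumed form): encode major/minor/patch/build into one integer and
// compare, so IS_TRT_VERSION_GE(7130) means TensorRT >= 7.1.3.0.
#define IS_TRT_VERSION_GE(version)                       \
  ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
    NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version)
```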
@@ -60,6 +60,7 @@ _out_scale_op_list = [
     "swish",
     "softmax",
     "batch_norm",
+    "layer_norm",
     "elementwise_add",
     "pool2d",
     "reshape2",
@@ -67,6 +68,7 @@ _out_scale_op_list = [
     "concat",
     "elementwise_mul",
     "scale",
+    "slice",
     "hard_swish",
     "hard_sigmoid",
     "conv2d_transpose",
@@ -119,6 +121,7 @@ _op_real_in_out_name = {
     "swish": [["X"], ["Out"]],
     "dropout": [["X"], ["Out"]],
     "batch_norm": [["X"], ["Y"]],
+    "layer_norm": [["X"], ["Y"]],
     "sigmoid": [["X"], ["Out"]],
     "elementwise_mul": [["X", "Y"], ["Out"]],
     "scale": [["X"], ["Out"]],
@@ -1749,7 +1752,7 @@ class AddQuantDequantPass(object):
         "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
         "squeeze", "elementwise_sub", "mul", "matmul", "relu", "relu6",
         "leaky_relu", "tanh", "swish", "scale", "transpose", "transpose2",
-        "sigmoid", "pad2d", "flatten", "flatten2", "batch_norm"
+        "sigmoid", "pad2d", "flatten", "flatten2", "batch_norm", "layer_norm"
     ]
     # To be compatible with PaddleSlim, not remove _activation_type for now
......