未验证 提交 8dbfc2ae 编写于 作者: C ceci3 提交者: GitHub

[paddle-inference]support setting fully connected in multi-head attention...

[paddle-inference]support setting fully connected in multi-head attention static shape branch to int8  (#39660)

* fix inference int

* update

* add unittest
上级 28fd30cd
...@@ -335,15 +335,37 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -335,15 +335,37 @@ class MultiheadMatMulOpConverter : public OpConverter {
reshape_before_fc_dim.d[4] = 1; reshape_before_fc_dim.d[4] = 1;
auto* reshape_before_fc_layer = auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (enable_int8) {
engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0),
in_scale);
}
reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim);
reshape_before_fc_layer->setName( reshape_before_fc_layer->setName(
("shuffle_before_multihead_mamul(Output: " + output_name + ")") ("shuffle_before_multihead_mamul(Output: " + output_name + ")")
.c_str()); .c_str());
// add layer fc // add layer fc
auto* fc_layer = TRT_ENGINE_ADD_LAYER( nvinfer1::ILayer* fc_layer = nullptr;
engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0), n, if (enable_int8) {
weight.get(), bias.get()); nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n,
nv_ksize, weight.get(), bias.get());
} else {
fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0),
n, weight.get(), bias.get());
}
if (enable_int8) {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
}
fc_layer->setName( fc_layer->setName(
("multihead_mamul_fc(Output: " + output_name + ")").c_str()); ("multihead_mamul_fc(Output: " + output_name + ")").c_str());
...@@ -359,6 +381,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -359,6 +381,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.push_back(input_bias_qk); plugin_inputs.push_back(input_bias_qk);
bool with_fp16 = bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (enable_int8) {
with_fp16 = 1;
}
plugin::DynamicPluginTensorRT* plugin = plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden_in, head_number, new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
head_size, scale, with_fp16); head_size, scale, with_fp16);
......
...@@ -451,10 +451,394 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -451,10 +451,394 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
"The output has diff between gpu and trt when dynamic fp32 mode and batch size > 2." "The output has diff between gpu and trt when dynamic fp32 mode and batch size > 2."
) )
def teller3(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(
teller3, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt in int8 mode.")
def test(self): def test(self):
self.add_skip_trt_case() self.add_skip_trt_case()
self.run_test() self.run_test()
class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
def sample_program_configs(self):
def generate_input1(batch, dim1):
return np.random.random((batch, dim1, 768)).astype(np.float32)
def generate_input2(shape):
return np.random.random(shape).astype(np.float32)
def generate_weight1():
return np.random.random((768, 768)).astype(np.float32)
def generate_weight2():
return np.random.random(768).astype(np.float32)
for batch in [1, 2, 4]:
self.batch = batch
for reshape_shape in [[0, 0, 12, 64]]:
for dim1 in [128]:
input2_shapes = [[batch, reshape_shape[2], dim1, dim1],
[batch, 1, 1, dim1]]
for input2_shape in input2_shapes:
for axis in [0]:
dics = [{
"x_num_col_dims": 2,
"y_num_col_dims": 1,
"enable_int8": True,
"X_scale": 1.0,
"weight_scale": [1.0],
}, {
"axis": 2,
"out_threshold": 1.0,
}, {
"shape": reshape_shape
}, {
"axis": [0, 2, 1, 3]
}, {
"x_num_col_dims": 2,
"y_num_col_dims": 1,
"enable_int8": True,
"X_scale": 1.0,
"weight_scale": [1.0],
}, {
"axis": 2,
"out_threshold": 1.0,
}, {
"shape": reshape_shape
}, {
"axis": [0, 2, 1, 3]
}, {
"x_num_col_dims": 2,
"y_num_col_dims": 1,
"enable_int8": True,
"X_scale": 1.0,
"weight_scale": [1.0],
}, {
"axis": 2,
"out_threshold": 1.0,
}, {
"shape": reshape_shape
}, {
"axis": [0, 2, 1, 3]
}, {
"scale": 0.125,
"bias": 0.0,
"bias_after_scale": True
}, {
"alpha": 1.0,
"transpose_X": False,
"transpose_Y": True,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": []
}, {
"axis": axis
}, {
"axis": -1,
"is_test": True
}, {
"seed": 0,
"dropout_prob": 0.10000000149011612,
"dropout_implementation": "upscale_in_train",
"fix_seed": False,
"is_test": True
}, {
"alpha": 1.0,
"transpose_X": False,
"transpose_Y": False,
"fused_reshape_X": [],
"fused_reshape_Y": [],
"fused_transpose_X": [],
"fused_transpose_Y": [],
"fused_reshape_Out": [],
"fused_transpose_Out": []
}, {
"axis": [0, 2, 1, 3]
}, {
"shape": [0, 0, 768]
}, {
"x_num_col_dims": 2,
"y_num_col_dims": 1
}]
ops_config = [
{
"op_type": "mul",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul1_weight"]
},
"op_outputs": {
"Out": ["mul1_output"]
},
"op_attrs": dics[0]
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["mul1_output"],
"Y": ["elementwise_add1_weight"]
},
"op_outputs": {
"Out": ["elementwise_add1_output"]
},
"op_attrs": dics[1]
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add1_output"],
},
"op_outputs": {
"Out": ["reshape21_output"],
"XShape": ["reshape21_output_xshape"]
},
"op_attrs": dics[2]
},
{
"op_type": "transpose2",
"op_inputs": {
"X": ["reshape21_output"]
},
"op_outputs": {
"Out": ["transpose21_output"],
"XShape":
["transpose21_output_xshape"]
},
"op_attrs": dics[3]
},
{
"op_type": "mul",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul2_weight"]
},
"op_outputs": {
"Out": ["mul2_output"]
},
"op_attrs": dics[4]
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["mul2_output"],
"Y": ["elementwise_add2_weight"]
},
"op_outputs": {
"Out": ["elementwise_add2_output"]
},
"op_attrs": dics[5]
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add2_output"]
},
"op_outputs": {
"Out": ["reshape22_output"],
"XShape": ["reshape22_output_xshape"]
},
"op_attrs": dics[6]
},
{
"op_type": "transpose2",
"op_inputs": {
"X": ["reshape22_output"]
},
"op_outputs": {
"Out": ["transpose22_output"],
"XShape":
["transpose22_output_xshape"]
},
"op_attrs": dics[7]
},
{
"op_type": "mul",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul3_weight"]
},
"op_outputs": {
"Out": ["mul3_output"]
},
"op_attrs": dics[8]
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["mul3_output"],
"Y": ["elementwise_add3_weight"]
},
"op_outputs": {
"Out": ["elementwise_add3_output"]
},
"op_attrs": dics[9]
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_add3_output"]
},
"op_outputs": {
"Out": ["reshape23_output"],
"XShape": ["reshape23_output_xshape"]
},
"op_attrs": dics[10]
},
{
"op_type": "transpose2",
"op_inputs": {
"X": ["reshape23_output"]
},
"op_outputs": {
"Out": ["transpose23_output"],
"XShape":
["transpose23_output_xshape"]
},
"op_attrs": dics[11]
},
{
"op_type": "scale",
"op_inputs": {
"X": ["transpose23_output"],
},
"op_outputs": {
"Out": ["scale_output"]
},
"op_attrs": dics[12]
},
{
"op_type": "matmul",
"op_inputs": {
"X": ["scale_output"],
"Y": ["transpose22_output"],
},
"op_outputs": {
"Out": ["matmul1_output"]
},
"op_attrs": dics[13]
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul1_output"],
"Y": ["input_data2"]
},
"op_outputs": {
"Out": ["elementwise_add4_output"]
},
"op_attrs": dics[14]
},
{
"op_type": "softmax",
"op_inputs": {
"X": ["elementwise_add4_output"]
},
"op_outputs": {
"Out": ["softmax_output"]
},
"op_attrs": dics[15]
},
{
"op_type": "dropout",
"op_inputs": {
"X": ["softmax_output"],
},
"op_outputs": {
"Out": ["dropout3_output"]
},
"op_attrs": dics[16]
},
{
"op_type": "matmul",
"op_inputs": {
"X": ["dropout3_output"],
"Y": ["transpose21_output"],
},
"op_outputs": {
"Out": ["matmul2_output"]
},
"op_attrs": dics[17]
},
{
"op_type": "transpose2",
"op_inputs": {
"X": ["matmul2_output"]
},
"op_outputs": {
"Out": ["transpose24_output"],
"XShape":
["transpose24_output_xshape"]
},
"op_attrs": dics[18]
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["transpose24_output"]
},
"op_outputs": {
"Out": ["reshape24_output"],
"XShape": ["reshape24_output_xshape"]
},
"op_attrs": dics[19]
},
# In order to fuse ops with
# multihead_matmul_fuse_pass_v2, the last op
# must be mul.
{
"op_type": "mul",
"op_inputs": {
"X": ["reshape24_output"],
"Y": ["mul4_weight"]
},
"op_outputs": {
"Out": ["mul4_output"]
},
"op_attrs": dics[20]
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"mul1_weight": TensorConfig(
data_gen=partial(generate_weight1)),
"mul2_weight": TensorConfig(
data_gen=partial(generate_weight1)),
"mul3_weight": TensorConfig(
data_gen=partial(generate_weight1)),
"mul4_weight": TensorConfig(
data_gen=partial(generate_weight1)),
"elementwise_add1_weight": TensorConfig(
data_gen=partial(generate_weight2)),
"elementwise_add2_weight": TensorConfig(
data_gen=partial(generate_weight2)),
"elementwise_add3_weight": TensorConfig(
data_gen=partial(generate_weight2)),
},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input1, batch,
dim1)),
"input_data2": TensorConfig(
data_gen=partial(generate_input2,
input2_shape)),
},
outputs=["mul4_output"])
yield program_config
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册