diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc index a8147fd466b5216b197f9f275ea3e7abf6ff99f9..c826e1c5a584acccd3adbc2529980cb59465926f 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc @@ -62,7 +62,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { // BOOST_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale")); // create multihead - OpDesc multihead_op_desc; + OpDesc multihead_op_desc(mul0->Op()->Block()); // create tmp tensor VarDesc k_var_desc(*mul1_out->Var()); @@ -847,7 +847,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, int head_number = BOOST_GET_CONST(std::vector, reshape_desc->GetAttr("shape")).at(2); - OpDesc multihead_op_desc; + OpDesc multihead_op_desc(mul0->Op()->Block()); multihead_op_desc.SetType("multihead_matmul"); multihead_op_desc.SetInput("Input", {input0->Name()}); @@ -1287,7 +1287,7 @@ int MultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph, int head_number = BOOST_GET_CONST(std::vector, reshape_desc->GetAttr("shape")).at(2); - OpDesc multihead_op_desc; + OpDesc multihead_op_desc(mul0->Op()->Block()); multihead_op_desc.SetType("multihead_matmul"); multihead_op_desc.SetInput("Input", {input0->Name()}); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index a073acc96c0d4f27e32ccf61dfe0b1414973e7cc..3935342e70296ea65943d960b51e498ec719e2af 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -23,7 +23,6 @@ class MultiheadMatMulOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { -#if IS_TRT_VERSION_GE(6000) VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); @@ -46,10 +45,6 @@ class MultiheadMatMulOpConverter : public OpConverter { float in_scale = 0.; if (enable_int8) { - PADDLE_ENFORCE_EQ( - op_desc.HasAttr("Input_scale"), true, - platform::errors::InvalidArgument( - "must have input scale in multihead layers in int8 mode")); in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; auto weight_scale = BOOST_GET_CONST(std::vector, op_desc.GetAttr("weight_scale")); @@ -181,10 +176,7 @@ class MultiheadMatMulOpConverter : public OpConverter { {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1}, {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, - { "var_seqlen", - &var_seqlen, - nvinfer1::PluginFieldType::kINT32, - 1 }}; + {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}}; if (qkv2context_plugin_int8) { fields.push_back( {"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1}); @@ -296,11 +288,6 @@ class MultiheadMatMulOpConverter : public OpConverter { } RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, test_mode); -#else - PADDLE_THROW(platform::errors::Fatal( - "You are running the TRT Dynamic Shape mode, need to confirm that " - "your TRT version is no less than 6.0")); -#endif } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 921b455149969ceada2c594ac9922bc5d13f5dde..ea630a9c6db903ded3d68c01cca7b7028630be1f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1085,6 +1085,42 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, VLOG(3) << "the multihead_matmul does not support static shape yet"; return false; } + + if (desc.HasAttr("enable_int8") && !desc.HasAttr("Input_scale")) { + VLOG(3) << "Multihead layers must have input scale in int8 mode."; + return false; + } + + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } + auto* input_desc = block->FindVar(desc.Input("Input").front()); + const auto input_shape = input_desc->GetShape(); + const auto head_number = + BOOST_GET_CONST(int, desc.GetAttr("head_number")); + + auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front()); + const auto biasqk_shape = biasqk_desc->GetShape(); + // The BiasQK's shape requires to be + // [batch, 1, 1, length] or [batch, head, length, length]. + bool has_same_shape = head_number == biasqk_shape[1] && + input_shape[1] == biasqk_shape[2] && + input_shape[1] == biasqk_shape[3]; + bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 && + input_shape[1] == biasqk_shape[3]; + if (!(has_same_shape || is_broadcastable)) { + VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0] + << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0] + << ", " << head_number << ", " << input_shape[1] << ", " + << input_shape[1] << "] but [" << biasqk_shape[0] << ", " + << biasqk_shape[1] << ", " << biasqk_shape[2] << ", " + << biasqk_shape[3] << "]."; + return false; + } } if (op_type == "fc") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..e772df522b5c5096378cc7b6008dd236cdbcfd91 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py @@ -0,0 +1,436 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input1(batch, dim1): + return np.random.randn(batch, dim1, 768).astype(np.float32) + + def generate_input2(shape): + return np.random.random(shape).astype(np.float32) + + def generate_weight1(): + return np.random.randn(768, 768).astype(np.float32) + + def generate_weight2(): + return np.random.randn(768).astype(np.float32) + + for batch in [1, 2, 4]: + for reshape_shape in [[0, 0, 12, 64]]: + for dim1 in [128]: + input2_shapes = [[batch, reshape_shape[2], dim1, dim1], + [batch, 1, 1, dim1]] + for input2_shape in input2_shapes: + for axis in [0]: + dics = [{ + "x_num_col_dims": 2, + "y_num_col_dims": 1 + }, { + "axis": 2 + }, { + "shape": reshape_shape + }, { + "axis": [0, 2, 1, 3] + }, { + "x_num_col_dims": 2, + "y_num_col_dims": 1 + }, { + "axis": 2 + }, { + "shape": reshape_shape + }, { + "axis": [0, 2, 1, 3] + }, { + "x_num_col_dims": 2, + "y_num_col_dims": 1 + }, { + "axis": 2 + }, { + "shape": reshape_shape + }, { + "axis": [0, 2, 1, 3] + }, { + "scale": 0.125, + "bias": 0.0, + "bias_after_scale": True + }, { + "alpha": 1.0, + "transpose_X": False, + "transpose_Y": True, + "fused_reshape_X": [], + "fused_reshape_Y": [], + "fused_transpose_X": [], + "fused_transpose_Y": [], + "fused_reshape_Out": [], + "fused_transpose_Out": [] + }, { + "axis": axis + }, { + "axis": -1, + "is_test": True + }, { + "seed": 0, + "dropout_prob": 0.10000000149011612, + "dropout_implementation": "upscale_in_train", + "fix_seed": False, + "is_test": True + }, { + "alpha": 1.0, + "transpose_X": False, + "transpose_Y": False, + "fused_reshape_X": [], + "fused_reshape_Y": [], + "fused_transpose_X": [], + "fused_transpose_Y": [], + "fused_reshape_Out": [], + "fused_transpose_Out": [] + }, { + "axis": [0, 2, 1, 3] + }, { + "shape": [0, 0, 768] + }, { + "x_num_col_dims": 2, + "y_num_col_dims": 1 + }] + + ops_config = [ + { + "op_type": "mul", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul1_weight"] + }, + "op_outputs": { + "Out": ["mul1_output"] + }, + "op_attrs": dics[0] + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["mul1_output"], + "Y": ["elementwise_add1_weight"] + }, + "op_outputs": { + "Out": ["elementwise_add1_output"] + }, + "op_attrs": dics[1] + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_add1_output"], + }, + "op_outputs": { + "Out": ["reshape21_output"], + "XShape": ["reshape21_output_xshape"] + }, + "op_attrs": dics[2] + }, + { + "op_type": "transpose2", + "op_inputs": { + "X": ["reshape21_output"] + }, + "op_outputs": { + "Out": ["transpose21_output"], + "XShape": + ["transpose21_output_xshape"] + }, + "op_attrs": dics[3] + }, + { + "op_type": "mul", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul2_weight"] + }, + "op_outputs": { + "Out": ["mul2_output"] + }, + "op_attrs": dics[4] + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["mul2_output"], + "Y": ["elementwise_add2_weight"] + }, + "op_outputs": { + "Out": ["elementwise_add2_output"] + }, + "op_attrs": dics[5] + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_add2_output"] + }, + "op_outputs": { + "Out": ["reshape22_output"], + "XShape": ["reshape22_output_xshape"] + }, + "op_attrs": dics[6] + }, + { + "op_type": "transpose2", + "op_inputs": { + "X": ["reshape22_output"] + }, + "op_outputs": { + "Out": ["transpose22_output"], + "XShape": + ["transpose22_output_xshape"] + }, + "op_attrs": dics[7] + }, + { + "op_type": "mul", + "op_inputs": { + "X": ["input_data1"], + "Y": ["mul3_weight"] + }, + "op_outputs": { + "Out": ["mul3_output"] + }, + "op_attrs": dics[8] + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["mul3_output"], + "Y": ["elementwise_add3_weight"] + }, + "op_outputs": { + "Out": ["elementwise_add3_output"] + }, + "op_attrs": dics[9] + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["elementwise_add3_output"] + }, + "op_outputs": { + "Out": ["reshape23_output"], + "XShape": ["reshape23_output_xshape"] + }, + "op_attrs": dics[10] + }, + { + "op_type": "transpose2", + "op_inputs": { + "X": ["reshape23_output"] + }, + "op_outputs": { + "Out": ["transpose23_output"], + "XShape": + ["transpose23_output_xshape"] + }, + "op_attrs": dics[11] + }, + { + "op_type": "scale", + "op_inputs": { + "X": ["transpose23_output"], + }, + "op_outputs": { + "Out": ["scale_output"] + }, + "op_attrs": dics[12] + }, + { + "op_type": "matmul", + "op_inputs": { + "X": ["scale_output"], + "Y": ["transpose22_output"], + }, + "op_outputs": { + "Out": ["matmul1_output"] + }, + "op_attrs": dics[13] + }, + { + "op_type": "elementwise_add", + "op_inputs": { + "X": ["matmul1_output"], + "Y": ["input_data2"] + }, + "op_outputs": { + "Out": ["elementwise_add4_output"] + }, + "op_attrs": dics[14] + }, + { + "op_type": "softmax", + "op_inputs": { + "X": ["elementwise_add4_output"] + }, + "op_outputs": { + "Out": ["softmax_output"] + }, + "op_attrs": dics[15] + }, + { + "op_type": "dropout", + "op_inputs": { + "X": ["softmax_output"], + }, + "op_outputs": { + "Out": ["dropout3_output"] + }, + "op_attrs": dics[16] + }, + { + "op_type": "matmul", + "op_inputs": { + "X": ["dropout3_output"], + "Y": ["transpose21_output"], + }, + "op_outputs": { + "Out": ["matmul2_output"] + }, + "op_attrs": dics[17] + }, + { + "op_type": "transpose2", + "op_inputs": { + "X": ["matmul2_output"] + }, + "op_outputs": { + "Out": ["transpose24_output"], + "XShape": + ["transpose24_output_xshape"] + }, + "op_attrs": dics[18] + }, + { + "op_type": "reshape2", + "op_inputs": { + "X": ["transpose24_output"] + }, + "op_outputs": { + "Out": ["reshape24_output"], + "XShape": ["reshape24_output_xshape"] + }, + "op_attrs": dics[19] + }, + # In order to fuse ops with + # multihead_matmul_fuse_pass_v2, the last op + # must be mul. + { + "op_type": "mul", + "op_inputs": { + "X": ["reshape24_output"], + "Y": ["mul4_weight"] + }, + "op_outputs": { + "Out": ["mul4_output"] + }, + "op_attrs": dics[20] + } + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "mul1_weight": TensorConfig( + data_gen=partial(generate_weight1)), + "mul2_weight": TensorConfig( + data_gen=partial(generate_weight1)), + "mul3_weight": TensorConfig( + data_gen=partial(generate_weight1)), + "mul4_weight": TensorConfig( + data_gen=partial(generate_weight1)), + "elementwise_add1_weight": TensorConfig( + data_gen=partial(generate_weight2)), + "elementwise_add2_weight": TensorConfig( + data_gen=partial(generate_weight2)), + "elementwise_add3_weight": TensorConfig( + data_gen=partial(generate_weight2)), + }, + inputs={ + "input_data1": TensorConfig( + data_gen=partial(generate_input1, batch, + dim1)), + "input_data2": TensorConfig( + data_gen=partial(generate_input2, + input2_shape)), + }, + outputs=["mul4_output"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + # The last dim of input1 and input2 should be static. + self.dynamic_shape.min_input_shape = { + "input_data1": [1, 8, 768], + "input_data2": [1, 1, 1, 128], + "reshape24_output": [1, 128, 768] + } + self.dynamic_shape.max_input_shape = { + "input_data1": [16, 512, 768], + "input_data2": [16, 256, 512, 128], + "reshape24_output": [1, 128, 768] + } + self.dynamic_shape.opt_input_shape = { + "input_data1": [8, 128, 768], + "input_data2": [8, 32, 64, 128], + "reshape24_output": [1, 128, 768] + } + + def clear_dynamic_shape(): + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), (1, 4), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 4), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), (1, 3), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), (1, 3), 1e-5 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main()