diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 363b3132a1536b64862433fd5edfc51d483612fc..22bd172e93b40f773a3e211d32e939f73f2a40a7 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -1744,13 +1744,21 @@ struct SimpleOpTypeSetTeller : public Teller {
                               input_shape[1] == biasqk_shape[3];
         bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
                                 input_shape[1] == biasqk_shape[3];
+        // A BiasQK of shape [1, 1, input_shape[1], input_shape[1]] is accepted
+        // too; it is broadcast over the batch and head dimensions.
+        is_broadcastable =
+            is_broadcastable || (biasqk_shape[0] == 1 && biasqk_shape[1] == 1 &&
+                                 input_shape[1] == biasqk_shape[2] &&
+                                 input_shape[1] == biasqk_shape[3]);
         if (!(has_same_shape || is_broadcastable)) {
           VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0]
-                  << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0]
-                  << ", " << head_number << ", " << input_shape[1] << ", "
-                  << input_shape[1] << "] but [" << biasqk_shape[0] << ", "
-                  << biasqk_shape[1] << ", " << biasqk_shape[2] << ", "
-                  << biasqk_shape[3] << "].";
+                  << ", 1, 1, " << input_shape[1] << "] "
+                  << "or [" << input_shape[0] << ", " << head_number << ", "
+                  << input_shape[1] << ", " << input_shape[1] << "] "
+                  << "or [1, 1, " << input_shape[1] << ", " << input_shape[1]
+                  << "] "
+                  << "but got [" << biasqk_shape[0] << ", " << biasqk_shape[1]
+                  << ", " << biasqk_shape[2] << ", " << biasqk_shape[3] << "].";
           return false;
         }
       } else {
diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 5e3f078cf9f4d586501d10ec34be8ac25ea8868a..731441463df7ebaaf7d05e84a956bf76577cc939 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -309,6 +309,25 @@ __global__ void broadcast(const T *src,
   }
 }
 
+// Broadcasts a [1, 1, seq_len, seq_len] bias to
+// [batch_size, head_num, seq_len, seq_len]. Each block copies one destination
+// row, so the launch needs batch_size * head_num * seq_len blocks of at least
+// seq_len threads. batch_size and head_num are not read by the kernel itself;
+// they only matter through the grid size chosen by the caller.
+template <typename T>
+__global__ void broadcast_batch_head_number(const T *src,
+                                            T *dst,
+                                            const int batch_size,
+                                            const int seq_len,
+                                            const int head_num) {
+  // Row of the seq_len x seq_len source matrix that this block replicates.
+  int src_seq_id = blockIdx.x % seq_len;
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len];
+  }
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc,
@@ -353,6 +372,27 @@ int QkvToContextPluginDynamic::enqueue(
           head_number_);
       qk_bias = temp_qk_bias;
     }
+    // fit to [batch, head_num, length, length] + [1, 1, length, length]
+    // The dims are tested instead of the element count because
+    // seq_len * seq_len == batch * seq_len whenever batch == seq_len, and a
+    // [batch, 1, 1, length] bias must not be expanded a second time here.
+    if (input_desc[1].dims.nbDims == 4 && input_desc[1].dims.d[0] == 1 &&
+        input_desc[1].dims.d[1] == 1 && input_desc[1].dims.d[2] == seq_len &&
+        input_desc[1].dims.d[3] == seq_len) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias =
+          reinterpret_cast<float *>(temp_qk_bias_tensor.mutable_data<float>(
+              platform::CUDAPlace(device_id)));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+          static_cast<const float *>(inputs[1]),
+          temp_qk_bias,
+          batch,
+          seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
     // fake qk_bias
     if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
       qk_bias = fake_qk_bias_;
@@ -424,6 +464,27 @@ int QkvToContextPluginDynamic::enqueue(
           head_number_);
       qk_bias = temp_qk_bias;
     }
+    // fit to [batch, head_num, length, length] + [1, 1, length, length]
+    // The dims are tested instead of the element count because
+    // seq_len * seq_len == batch * seq_len whenever batch == seq_len, and a
+    // [batch, 1, 1, length] bias must not be expanded a second time here.
+    if (input_desc[1].dims.nbDims == 4 && input_desc[1].dims.d[0] == 1 &&
+        input_desc[1].dims.d[1] == 1 && input_desc[1].dims.d[2] == seq_len &&
+        input_desc[1].dims.d[3] == seq_len) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias =
+          reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
+              platform::CUDAPlace(device_id)));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+          static_cast<const half *>(inputs[1]),
+          temp_qk_bias,
+          batch,
+          seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
     // padding:    mask_half_ = [1.0,....1.0...1.0....,0.0f]
     // no_padding: mask_half_ = [1.0,....1.0,.........,1.0f]
     bool bias_is_mask = false;
diff --git
a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index f1deedce5f133a4457a321e48db72dc6ddc96926..2e8b6f7d0a6b8a67d7cc1036bf3f83b61745c30f 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -256,6 +256,25 @@ __global__ void broadcast(const T *src,
   }
 }
 
+// Broadcasts a [1, 1, seq_len, seq_len] bias to
+// [batch_size, head_num, seq_len, seq_len]. Each block copies one destination
+// row, so the launch needs batch_size * head_num * seq_len blocks of at least
+// seq_len threads. batch_size and head_num are not read by the kernel itself;
+// they only matter through the grid size chosen by the caller.
+template <typename T>
+__global__ void broadcast_batch_head_number(const T *src,
+                                            T *dst,
+                                            const int batch_size,
+                                            const int seq_len,
+                                            const int head_num) {
+  // Row of the seq_len x seq_len source matrix that this block replicates.
+  int src_seq_id = blockIdx.x % seq_len;
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_seq_id * seq_len];
+  }
+}
+
 template <typename DeviceContext, typename T>
 class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
  public:
@@ -286,6 +305,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     Tensor temp_bias_tensor;
     // if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
     if (bias_qk && bias_qk->numel() == (batch * seq_len)) {
+      VLOG(4) << "Do broadcasted bias_qk from [batch, 1, 1, seq_len]";
       temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
       auto *temp_qk_bias = device_ctx.template Alloc<T>(
           &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
@@ -295,6 +315,25 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
           bias_qk_d, temp_qk_bias, seq_len, head_number);
       bias_qk_d = static_cast<const T *>(temp_qk_bias);
     }
+    // if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
+    // broadcasted. The shape itself is tested, not just numel(): when
+    // batch == seq_len a [batch, 1, 1, seq_len] bias has the same element
+    // count and has already been expanded by the branch above.
+    if (bias_qk && bias_qk->dims().size() == 4 && bias_qk->dims()[0] == 1 &&
+        bias_qk->dims()[1] == 1 && bias_qk->dims()[2] == seq_len &&
+        bias_qk->dims()[3] == seq_len) {
+      VLOG(4) << "do broadcasted bias_qk from [1, 1, seq_len, seq_len]";
+      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
+      auto *temp_qk_bias = device_ctx.template Alloc<T>(
+          &temp_bias_tensor, temp_bias_tensor.numel() * sizeof(T));
+      int grid = batch * head_number * seq_len;
+      int block = round_up(seq_len);
+      // Read from the original tensor: bias_qk_d may already point at
+      // temp_bias_tensor if the [batch, 1, 1, seq_len] branch ran.
+      broadcast_batch_head_number<<<grid, block, 0, stream>>>(
+          bias_qk->data<T>(), temp_qk_bias, batch, seq_len, head_number);
+      bias_qk_d = static_cast<const T *>(temp_qk_bias);
+    }
     if (!bias_qk) {
       int size = batch * head_number * seq_len * seq_len;
       temp_bias_tensor.Resize({size});
@@ -333,7 +372,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
     auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(device_ctx);
     blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
-
+    VLOG(2) << "(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)";
     // temp_out_tensor.Resize(temp_out_dims);
 
     Tensor multihead_temp_tensor;
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
index 9dd7ae4a8f432509fe4d4be09f9bbb7f0afb67f5..074b55d5df1ad6a9866e914f1f9372a969b483e0 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
@@ -1081,5 +1081,422 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
         self.run_test()
 
 
+class TrtConvertMultiHeadMatmulTest_biasqk_seqseq(TrtLayerAutoScanTest):
+    """TRT conversion test for the multihead pattern over BiasQK shapes."""
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input1(batch, dim1):
+            return np.random.random((batch, dim1, 768)).astype(np.float32)
+
+        def generate_input2(shape):
+            return np.random.random(shape).astype(np.float32)
+
+        def generate_weight1():
+            return np.random.random((768, 768)).astype(np.float32)
+
+        def generate_weight2():
+            return np.random.random(768).astype(np.float32)
+
+        def generate_weight3():
+            return np.random.random((768, 768)).astype(np.float32)
+
+        for batch in [2]:
+            self.batch = batch
+            for reshape_shape in [[0, 0, 12, 64]]:
+                for dim1 in [128]:
+                    input2_shapes = [
+                        [batch, reshape_shape[2], dim1, dim1],
+                        [batch, 1, 1, dim1],
+                        # BiasQK broadcast over batch and head.
+                        [1, 1, dim1, dim1],
+                    ]
+                    for
input2_shape in input2_shapes:
+                        for axis in [0]:
+                            dics = [
+                                {"x_num_col_dims": 2, "y_num_col_dims": 1},
+                                {"axis": 2},
+                                {"shape": reshape_shape},
+                                {"axis": [0, 2, 1, 3]},
+                                {"x_num_col_dims": 2, "y_num_col_dims": 1},
+                                {"axis": 2},
+                                {"shape": reshape_shape},
+                                {"axis": [0, 2, 1, 3]},
+                                {"x_num_col_dims": 2, "y_num_col_dims": 1},
+                                {"axis": 2},
+                                {"shape": reshape_shape},
+                                {"axis": [0, 2, 1, 3]},
+                                {
+                                    "scale": 0.125,
+                                    "bias": 0.0,
+                                    "bias_after_scale": True,
+                                },
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": True,
+                                    "fused_reshape_X": [],
+                                    "fused_reshape_Y": [],
+                                    "fused_transpose_X": [],
+                                    "fused_transpose_Y": [],
+                                    "fused_reshape_Out": [],
+                                    "fused_transpose_Out": [],
+                                },
+                                {"axis": axis},
+                                {"axis": -1, "is_test": True},
+                                {
+                                    "seed": 0,
+                                    "dropout_prob": 0.10000000149011612,
+                                    "dropout_implementation": "upscale_in_train",
+                                    "fix_seed": False,
+                                    "is_test": True,
+                                },
+                                {
+                                    "alpha": 1.0,
+                                    "transpose_X": False,
+                                    "transpose_Y": False,
+                                    "fused_reshape_X": [],
+                                    "fused_reshape_Y": [],
+                                    "fused_transpose_X": [],
+                                    "fused_transpose_Y": [],
+                                    "fused_reshape_Out": [],
+                                    "fused_transpose_Out": [],
+                                },
+                                {"axis": [0, 2, 1, 3]},
+                                {"shape": [0, 0, 768]},
+                                {"x_num_col_dims": 2, "y_num_col_dims": 1},
+                            ]
+
+                            ops_config = [
+                                {
+                                    "op_type": "mul",
+                                    "op_inputs": {
+                                        "X": ["input_data1"],
+                                        "Y": ["mul1_weight"],
+                                    },
+                                    "op_outputs": {"Out": ["mul1_output"]},
+                                    "op_attrs": dics[0],
+                                },
+                                {
+                                    "op_type": "elementwise_add",
+                                    "op_inputs": {
+                                        "X": ["mul1_output"],
+                                        "Y": ["elementwise_add1_weight"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["elementwise_add1_output"]
+                                    },
+                                    "op_attrs": dics[1],
+                                },
+                                {
+                                    "op_type": "reshape2",
+                                    "op_inputs": {
+                                        "X": ["elementwise_add1_output"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["reshape21_output"],
+                                        "XShape": ["reshape21_output_xshape"],
+                                    },
+                                    "op_attrs": dics[2],
+                                },
+                                {
+                                    "op_type": "transpose2",
+                                    "op_inputs": {"X": ["reshape21_output"]},
+                                    "op_outputs": {
+                                        "Out": ["transpose21_output"],
+                                        "XShape": ["transpose21_output_xshape"],
+                                    },
+                                    "op_attrs": dics[3],
+                                },
+                                {
+                                    "op_type": "mul",
+                                    "op_inputs": {
+                                        "X": ["input_data1"],
+                                        "Y": ["mul2_weight"],
+                                    },
+                                    "op_outputs": {"Out": ["mul2_output"]},
+                                    "op_attrs": dics[4],
+                                },
+                                {
+                                    "op_type": "elementwise_add",
+                                    "op_inputs": {
+                                        "X": ["mul2_output"],
+                                        "Y": ["elementwise_add2_weight"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["elementwise_add2_output"]
+                                    },
+                                    "op_attrs": dics[5],
+                                },
+                                {
+                                    "op_type": "reshape2",
+                                    "op_inputs": {
+                                        "X": ["elementwise_add2_output"]
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["reshape22_output"],
+                                        "XShape": ["reshape22_output_xshape"],
+                                    },
+                                    "op_attrs": dics[6],
+                                },
+                                {
+                                    "op_type": "transpose2",
+                                    "op_inputs": {"X": ["reshape22_output"]},
+                                    "op_outputs": {
+                                        "Out": ["transpose22_output"],
+                                        "XShape": ["transpose22_output_xshape"],
+                                    },
+                                    "op_attrs": dics[7],
+                                },
+                                {
+                                    "op_type": "mul",
+                                    "op_inputs": {
+                                        "X": ["input_data1"],
+                                        "Y": ["mul3_weight"],
+                                    },
+                                    "op_outputs": {"Out": ["mul3_output"]},
+                                    "op_attrs": dics[8],
+                                },
+                                {
+                                    "op_type": "elementwise_add",
+                                    "op_inputs": {
+                                        "X": ["mul3_output"],
+                                        "Y": ["elementwise_add3_weight"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["elementwise_add3_output"]
+                                    },
+                                    "op_attrs": dics[9],
+                                },
+                                {
+                                    "op_type": "reshape2",
+                                    "op_inputs": {
+                                        "X": ["elementwise_add3_output"]
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["reshape23_output"],
+                                        "XShape": ["reshape23_output_xshape"],
+                                    },
+                                    "op_attrs": dics[10],
+                                },
+                                {
+                                    "op_type": "transpose2",
+                                    "op_inputs": {"X": ["reshape23_output"]},
+                                    "op_outputs": {
+                                        "Out": ["transpose23_output"],
+                                        "XShape": ["transpose23_output_xshape"],
+                                    },
+                                    "op_attrs": dics[11],
+                                },
+                                {
+                                    "op_type": "scale",
+                                    "op_inputs": {
+                                        "X": ["transpose23_output"],
+                                    },
+                                    "op_outputs": {"Out": ["scale_output"]},
+                                    "op_attrs": dics[12],
+                                },
+                                {
+                                    "op_type": "matmul",
+                                    "op_inputs": {
+                                        "X": ["scale_output"],
+                                        "Y": ["transpose22_output"],
+                                    },
+                                    "op_outputs": {"Out": ["matmul1_output"]},
+                                    "op_attrs": dics[13],
+                                },
+                                {
+                                    "op_type": "elementwise_add",
+                                    "op_inputs": {
+                                        "X": ["matmul1_output"],
+                                        "Y": ["input_data2"],
+                                    },
+                                    "op_outputs": {
+                                        "Out": ["elementwise_add4_output"]
+                                    },
+                                    "op_attrs": dics[14],
+                                },
+                                {
+                                    "op_type": "softmax",
+                                    "op_inputs": {
+                                        "X": ["elementwise_add4_output"]
+                                    },
+                                    "op_outputs": {"Out": ["softmax_output"]},
+                                    "op_attrs": dics[15],
+                                },
+                                {
+                                    "op_type": "dropout",
+                                    "op_inputs": {
+                                        "X": ["softmax_output"],
+                                    },
+                                    "op_outputs": {"Out": ["dropout3_output"]},
+                                    "op_attrs": dics[16],
+                                },
+                                {
+                                    "op_type": "matmul",
+                                    "op_inputs": {
+                                        "X": ["dropout3_output"],
+                                        "Y": ["transpose21_output"],
+                                    },
+                                    "op_outputs": {"Out": ["matmul2_output"]},
+                                    "op_attrs": dics[17],
+                                },
+                                {
+                                    "op_type": "transpose2",
+                                    "op_inputs": {"X": ["matmul2_output"]},
+                                    "op_outputs": {
+                                        "Out": ["transpose24_output"],
+                                        "XShape": ["transpose24_output_xshape"],
+                                    },
+                                    "op_attrs": dics[18],
+                                },
+                                {
+                                    "op_type": "reshape2",
+                                    "op_inputs": {"X": ["transpose24_output"]},
+                                    "op_outputs": {
+                                        "Out": ["reshape24_output"],
+                                        "XShape": ["reshape24_output_xshape"],
+                                    },
+                                    "op_attrs": dics[19],
+                                },
+                                # multihead_matmul_fuse_pass_v2 only fuses
+                                # these ops when the last op of the pattern
+                                # is a mul, so one is appended here.
+                                {
+                                    "op_type": "mul",
+                                    "op_inputs": {
+                                        "X": ["reshape24_output"],
+                                        "Y": ["mul4_weight"],
+                                    },
+                                    "op_outputs": {"Out": ["mul4_output"]},
+                                    "op_attrs": dics[20],
+                                },
+                            ]
+                            ops = self.generate_op_config(ops_config)
+
+                            program_config = ProgramConfig(
+                                ops=ops,
+                                weights={
+                                    "mul1_weight": TensorConfig(
+                                        data_gen=partial(generate_weight1)
+                                    ),
+                                    "mul2_weight": TensorConfig(
+                                        data_gen=partial(generate_weight1)
+                                    ),
+                                    "mul3_weight": TensorConfig(
+                                        data_gen=partial(generate_weight1)
+                                    ),
+                                    "mul4_weight": TensorConfig(
+                                        data_gen=partial(generate_weight1)
+                                    ),
+                                    "elementwise_add1_weight": TensorConfig(
+                                        data_gen=partial(generate_weight2)
+                                    ),
+                                    "elementwise_add2_weight": TensorConfig(
+                                        data_gen=partial(generate_weight3)
+                                    ),
+                                    "elementwise_add3_weight": TensorConfig(
+                                        data_gen=partial(generate_weight2)
+                                    ),
+                                },
+                                inputs={
+                                    "input_data1": TensorConfig(
+                                        data_gen=partial(
+                                            generate_input1, batch, dim1
+                                        )
+                                    ),
+                                    "input_data2": TensorConfig(
+                                        data_gen=partial(
+                                            generate_input2, input2_shape
+                                        )
+                                    ),
+                                },
+                                outputs=["mul4_output"],
+                            )
+
+                            yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            # The last dim of input_data1 and input_data2 must stay static.
+            self.dynamic_shape.min_input_shape = {
+                "input_data1": [1, 8, 768],
+                "input_data2": [1, 1, 1, 128],
+                "reshape24_output": [1, 128, 768],
+            }
+            self.dynamic_shape.max_input_shape = {
+                "input_data1": [16, 512, 768],
+                "input_data2": [16, 256, 512, 128],
+                "reshape24_output": [1, 128, 768],
+            }
+            self.dynamic_shape.opt_input_shape = {
+                "input_data1": [8, 128, 768],
+                "input_data2": [8, 32, 64, 128],
+                "reshape24_output": [1, 128, 768],
+            }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        self.trt_param.workspace_size = 2013265920
+        yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
+
+    def add_skip_trt_case(self):
+        def teller1(program_config, predictor_config):
+            if self.trt_param.precision == paddle_infer.PrecisionType.Half:
+                return True
+            return False
+
+        self.add_skip_case(
+            teller1,
+            SkipReasons.TRT_NOT_IMPLEMENTED,
+            "The output has diff between gpu and trt in fp16 mode.",
+        )
+
+        def teller2(program_config, predictor_config):
+            if (
+                self.trt_param.precision == paddle_infer.PrecisionType.Float32
+                and len(self.dynamic_shape.min_input_shape) != 0
+                and self.batch > 2
+            ):
+                return True
+            return False
+
+        self.add_skip_case(
+            teller2,
+            SkipReasons.TRT_NOT_IMPLEMENTED,
+            "The output has diff between gpu and trt when dynamic fp32 mode and batch size > 2.",
+        )
+
+        def teller3(program_config, predictor_config):
+            if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
+                return True
+            return False
+
+        self.add_skip_case(
+            teller3,
+            SkipReasons.TRT_NOT_IMPLEMENTED,
+            "The output has diff between gpu and trt in int8 mode.",
+        )
+
+    def test(self):
+        self.add_skip_trt_case()
+        self.run_test()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
index e2b53903b6d72a8de4568673c42ad2417ab0f669..55c2a563c8cdf9c223d0595e9ac067ec11011a88 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
@@ -29,6 +29,116 @@ def stable_softmax(x):
     return exps / np.sum(exps)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
+)
+class TestFusedMultiHeadMatmulOp_biasqk2(OpTest):
+    """multihead_matmul with BiasQK shaped [1, 1, seq_len, seq_len].
+
+    The NumPy reference tiles BiasQK across batch and heads before adding it
+    to the scaled Q*K product.
+    """
+    def config(self):
+        self.seq_len = 128
+        self.size_per_head = 64
+        self.head_number = 12
+        self.batch_size = 8
+        self.scale = 0.125
+
+    def setUp(self):
+        self.op_type = "multihead_matmul"
+        self.config()
+        h = self.seq_len
+        w = self.head_number * self.size_per_head
+        self.Input = (
+            np.random.random((self.batch_size, h, w)).astype("float32") - 0.5
+        )
+        self.WQ = np.random.random((w, w)).astype("float32")
+        self.KQ = np.random.random((w, w)).astype("float32")
+        self.VQ = np.random.random((w, w)).astype("float32")
+        self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
+            (w, 3, w)
+        )
+        self.Q = np.dot(self.Input, self.WQ)
+        self.K = np.dot(self.Input, self.KQ)
+        self.V = np.dot(self.Input, self.VQ)
+
+        self.BiasQ = np.random.random((1, w)).astype("float32")
+        self.BiasK = np.random.random((1, w)).astype("float32")
+        self.BiasV = np.random.random((1, w)).astype("float32")
+        self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
+        self.BiasQK = np.random.random(
+            (1, 1, self.seq_len, self.seq_len)
+        ).astype("float32")
+        # Compute Q path
+        fc_q = self.Q + self.BiasQ
+        reshape_q = np.reshape(
+            fc_q,
+            (
+                self.batch_size,
+                self.seq_len,
+                self.head_number,
+                self.size_per_head,
+            ),
+        )
+        transpose_q = np.transpose(reshape_q, (0, 2, 1, 3))
+        scale_q = self.scale * transpose_q
+        # Compute K path
+        fc_k = self.K + self.BiasK
+        reshape_k = np.reshape(
+            fc_k,
+            (
+                self.batch_size,
+                self.seq_len,
+                self.head_number,
+                self.size_per_head,
+            ),
+        )
+        transpose_k = np.transpose(reshape_k, (0, 2, 3, 1))
+
+        # Compute Q*K
+        q_k = np.matmul(scale_q, transpose_k)
+        eltadd_qk = q_k + np.tile(
+            self.BiasQK, [self.batch_size, self.head_number, 1, 1]
+        )
+        softmax_qk = np.apply_along_axis(stable_softmax, 3, eltadd_qk)
+        # Compute V path
+        fc_v = self.V + self.BiasV
+        reshape_v = np.reshape(
+            fc_v,
+            (
+                self.batch_size,
+                self.seq_len,
+                self.head_number,
+                self.size_per_head,
+            ),
+        )
+        transpose_v = np.transpose(reshape_v, (0, 2, 1, 3))
+
+        # Compute QK*V
+        qkv = np.matmul(softmax_qk, transpose_v)
+        transpose_qkv = np.transpose(qkv, (0, 2, 1, 3))
+        reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
+        self.inputs = {
+            "Input": self.Input,
+            "W": self.CombinedW,
+            "Bias": self.CombinedB,
+            "BiasQK": self.BiasQK,
+        }
+        self.attrs = {
+            "transpose_Q": False,
+            "transpose_K": True,
+            "transpose_V": False,
+            "head_number": self.head_number,
+            "alpha": self.scale,
+        }
+        self.outputs = {"Out": reshape_qkv}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, atol=2e-3)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "Paddle core is not compiled with CUDA"
 )