From cf9eae4c025ddcb376800cbd1ecd7fb104814e38 Mon Sep 17 00:00:00 2001
From: feng_shuai
Date: Fri, 17 Sep 2021 17:01:55 +0800
Subject: [PATCH] broadcast qkv_op (#35780)

* broadcast qkv_op

* use PADDLE_ENFORCE_GT to replace assert
---
 .../tensorrt/plugin/qkv_to_context_plugin.cu  | 51 ++++++++++++++++++-
 .../operators/fused/multihead_matmul_op.cu    | 32 +++++++++++-
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 0d978939c4b..6bae3606afe 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -233,6 +233,24 @@ __global__ void apply_scale(T *data, T scale, int n) {
 #endif
 }
 
+inline int round_up(int seq_len, int multiple = 32) {
+  PADDLE_ENFORCE_GT(
+      multiple, 0,
+      platform::errors::InvalidArgument(
+          "multiple should be a positive number, but it's (%d)", multiple));
+  return ((seq_len + multiple - 1) / multiple) * multiple;
+}
+
+template <typename T>
+__global__ void broadcast(const T *src, T *dst, const int seq_len,
+                          const int head_num) {
+  int batch_id = blockIdx.x / (head_num * seq_len);
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
+  }
+}
+
 int QkvToContextPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
@@ -258,7 +276,21 @@ int QkvToContextPluginDynamic::enqueue(
     auto *tptr = multihead_temp_data + scratch_size;
 
     const float *input0_data = static_cast<const float *>(inputs[0]);
-    const float *input1_data = static_cast<const float *>(inputs[1]);
+    // fit to [batch, head_num, length, length] + [batch, 1, 1, length]
+    framework::Tensor temp_qk_bias_tensor;
+    float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[1]));
+    if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
+          platform::CUDAPlace(device_id));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast<<<grid, block, 0, stream>>>(
+          static_cast<const float *>(inputs[1]), temp_qk_bias, seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
+    const float *input1_data = static_cast<const float *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
     TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
                  stream);
@@ -290,7 +322,22 @@ int QkvToContextPluginDynamic::enqueue(
     half *tptr = qkptr + scratch_size;
 
     const half *input0_data = static_cast<const half *>(inputs[0]);
-    const half *input1_data = static_cast<const half *>(inputs[1]);
+    // fit to [batch, head_num, length, length] + [batch, 1, 1, length]
+    framework::Tensor temp_qk_bias_tensor;
+    half *qk_bias = const_cast<half *>(static_cast<const half *>(inputs[1]));
+    if (ProductDim(input_desc[1].dims) == (batch * seq_len)) {
+      temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
+      auto *temp_qk_bias =
+          reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
+              platform::CUDAPlace(device_id)));
+      int grid = batch * head_number_ * seq_len;
+      int block = round_up(seq_len);
+      broadcast<<<grid, block, 0, stream>>>(
+          static_cast<const half *>(inputs[1]), temp_qk_bias, seq_len,
+          head_number_);
+      qk_bias = temp_qk_bias;
+    }
+    const half *input1_data = static_cast<const half *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
     TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
                  stream);
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index c19e621b18f..69056189ac2 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -132,6 +132,24 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
   }
 }
 
+inline int round_up(int seq_len, int multiple = 32) {
+  PADDLE_ENFORCE_GT(
+      multiple, 0,
+      platform::errors::InvalidArgument(
+          "multiple should be a positive number, but it's (%d)", multiple));
+  return ((seq_len + multiple - 1) / multiple) * multiple;
+}
+
+template <typename T>
+__global__ void broadcast(const T *src, T *dst, const int seq_len,
+                          const int head_num) {
+  int batch_id = blockIdx.x / (head_num * seq_len);
+  int dst_offset = blockIdx.x * seq_len;
+  if (threadIdx.x < seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
+  }
+}
+
 template <typename DeviceContext, typename T>
 class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
  public:
@@ -152,6 +170,7 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     int head_number = context.Attr<int>("head_number");
     // compute q*k with eltadd
     auto &device_ctx = context.template device_context<DeviceContext>();
+    auto stream = device_ctx.stream();
     // should be (B * S * hidden)
     auto input_dims = input->dims();
     // shouble be (hidden * 3 * all_head_size)
@@ -159,7 +178,17 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     int batch = input_dims[0];
     int seq_len = input_dims[1];
     int hidden = input_dims[2];
-
+    Tensor temp_bias_tensor;
+    // if bias_qk is [batch, 1, 1, seq_len], bias_qk_d needs to be broadcast
+    if (bias_qk.numel() == (batch * seq_len)) {
+      temp_bias_tensor.Resize({batch * head_number * seq_len * seq_len});
+      auto *temp_qk_bias = temp_bias_tensor.mutable_data<T>(context.GetPlace());
+      int grid = batch * head_number * seq_len;
+      int block = round_up(seq_len);
+      broadcast<<<grid, block, 0, stream>>>(bias_qk_d, temp_qk_bias, seq_len,
+                                            head_number);
+      bias_qk_d = static_cast<const T *>(temp_qk_bias);
+    }
     int all_head_size = w_dims[2];
     int head_size = all_head_size / head_number;
 
@@ -196,7 +225,6 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     auto *qkptr = multihead_temp_data;
     auto *tptr = multihead_temp_data + scratch_size;
 
-    auto stream = device_ctx.stream();
     // Do the transpose with bias.
     // BxSx3xNxH => tptr: 3xBxNxSxH.
     TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
--
GitLab
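
Note: below is a minimal standalone sketch (compiled with nvcc, outside Paddle) of what the broadcast kernel added in both files does: it expands an attention-mask bias of shape [batch, 1, 1, seq_len] into the [batch, head_num, seq_len, seq_len] layout that the QK bias add expects, using one CUDA block per destination row and a block size rounded up to a warp multiple. The problem sizes in main() and the host-side verification loop are illustrative assumptions, not part of the patch.

// broadcast_sketch.cu -- standalone illustration of the bias broadcast.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

inline int round_up(int seq_len, int multiple = 32) {
  return ((seq_len + multiple - 1) / multiple) * multiple;
}

template <typename T>
__global__ void broadcast(const T *src, T *dst, const int seq_len,
                          const int head_num) {
  // One block per destination row; batch * head_num * seq_len rows in total.
  int batch_id = blockIdx.x / (head_num * seq_len);
  int dst_offset = blockIdx.x * seq_len;
  if (threadIdx.x < seq_len) {
    dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
  }
}

int main() {
  const int batch = 2, head_num = 4, seq_len = 5;  // illustrative sizes
  std::vector<float> h_src(batch * seq_len);
  for (size_t i = 0; i < h_src.size(); ++i) h_src[i] = static_cast<float>(i);

  const size_t dst_elems =
      static_cast<size_t>(batch) * head_num * seq_len * seq_len;
  float *d_src = nullptr, *d_dst = nullptr;
  cudaMalloc(&d_src, h_src.size() * sizeof(float));
  cudaMalloc(&d_dst, dst_elems * sizeof(float));
  cudaMemcpy(d_src, h_src.data(), h_src.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  int grid = batch * head_num * seq_len;  // one row per block
  int block = round_up(seq_len);          // threads rounded to a warp multiple
  broadcast<float><<<grid, block>>>(d_src, d_dst, seq_len, head_num);

  std::vector<float> h_dst(dst_elems);
  cudaMemcpy(h_dst.data(), d_dst, dst_elems * sizeof(float),
             cudaMemcpyDeviceToHost);

  // Every (batch, head, row) slice should equal that batch's bias row.
  bool ok = true;
  for (int b = 0; b < batch && ok; ++b)
    for (int h = 0; h < head_num && ok; ++h)
      for (int i = 0; i < seq_len && ok; ++i)
        for (int j = 0; j < seq_len && ok; ++j)
          ok = h_dst[((b * head_num + h) * seq_len + i) * seq_len + j] ==
               h_src[b * seq_len + j];
  printf("broadcast %s\n", ok ? "matches expectation" : "MISMATCH");

  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}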