From f79be6567f03edf3298520780b772be7ab4ecf37 Mon Sep 17 00:00:00 2001
From: feng_shuai
Date: Fri, 2 Sep 2022 18:49:18 +0800
Subject: [PATCH] padding the length of input for vit_attention (#45506)

* vit_384_opt

* just support trt8

* padding + unpadding

* fix: unit test

* refactor: padding

* fix: change the position of round_up

* refactor: delete workspace
---
 .../tensorrt/convert/multihead_matmul_op.cc   |  13 +-
 .../tensorrt/plugin/qkv_to_context_plugin.cu  | 162 ++++++-
 .../tensorrt/plugin/qkv_to_context_plugin.h   |   5 +-
 .../test_trt_convert_multihead_matmul.py      | 417 +++++++++---------
 4 files changed, 378 insertions(+), 219 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
index 8443c92241b..40aa3940eda 100644
--- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
@@ -536,8 +536,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
                 "but it's (%d) now.",
                 input->getDimensions().nbDims));
         // transpose weight_data from m * n to n * m
-        auto* input_bias_qk =
-            engine_->GetITensor(op_desc.Input("BiasQK").front());
         TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                       static_cast<void*>(weight_data),
@@ -615,6 +613,17 @@ class MultiheadMatMulOpConverter : public OpConverter {
           std::vector<nvinfer1::ITensor*> plugin_inputs;
           plugin_inputs.push_back(fc_layer->getOutput(0));
+          auto inputs = op_desc.Inputs();
+          bool hasBiasQK =
+              (inputs.find("BiasQK") != inputs.end());
+          nvinfer1::ITensor* input_bias_qk = nullptr;
+          if (hasBiasQK) {
+            input_bias_qk =
+                engine_->GetITensor(op_desc.Input("BiasQK").front());
+          } else {
+            // fake input will be updated in qkv_plugin
+            input_bias_qk = fc_layer->getOutput(0);
+          }
           plugin_inputs.push_back(input_bias_qk);
           bool with_fp16 =
               engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
index 9602e6c8790..e4a9504d8c8 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
@@ -50,6 +50,71 @@ __global__ void transpose(T *src,
              threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
 }
 
+inline int round_up(int seq_len, int multiple = 32) {
+  PADDLE_ENFORCE_GT(
+      multiple,
+      0,
+      platform::errors::InvalidArgument(
+          "multiple should be a positive number, but it's (%d)", multiple));
+  return ((seq_len + multiple - 1) / multiple) * multiple;
+}
+
+template <typename T>
+__global__ void reset_qk_bias(T *input, int real_seq_len, int seq_len) {
+  if (threadIdx.x < seq_len) {
+    int id = threadIdx.x + blockIdx.x * seq_len;
+    input[id] = threadIdx.x >= real_seq_len ? (T)-1e20f : (T)0.0f;
+  }
+}
+
+template <typename T>
+__global__ void transpose_qkv_padding(
+    const T *src,  // (Batch, real_seq_len, 3 , head_num * size_per_head)
+    T *dst,        // (3 * batch * head_num * seq_len * size_per_head)
+    const int batch_size,
+    const int seq_len,
+    const int head_num,
+    const int size_per_head,
+    const int real_seq_len) {
+  // const dim3 grid(seq_len, batch, 3);
+  // const dim3 block(head_size, head_num, 1);
+  int qkv_id = blockIdx.z;
+  int batch_id = blockIdx.y;
+  int seq_id = blockIdx.x;
+  int head_id = threadIdx.y;
+  const int dst_offset =
+      qkv_id * batch_size * head_num * seq_len * size_per_head +
+      batch_id * head_num * seq_len * size_per_head +
+      head_id * seq_len * size_per_head + seq_id * size_per_head;
+  const int src_offset =
+      batch_id * real_seq_len * 3 * head_num * size_per_head +
+      seq_id * 3 * head_num * size_per_head +
+      qkv_id * head_num * size_per_head + head_id * size_per_head;
+  if (seq_id < real_seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
+  } else if (seq_id < seq_len) {
+    dst[threadIdx.x + dst_offset] = 0;
+  }
+}
+
+template <typename T>
+__global__ void transpose_qkv_unpadding(const T *src,
+                                        T *dst,
+                                        const int batch_size,
+                                        const int seq_len,
+                                        const int head_num,
+                                        const int size_per_head,
+                                        const int real_seq_len) {
+  int batch_id = blockIdx.x / (head_num * real_seq_len);
+  int seq_id = blockIdx.x % real_seq_len;
+  int head_id = blockIdx.x % (head_num * real_seq_len) / real_seq_len;
+  dst[batch_id * head_num * real_seq_len * size_per_head +
+      seq_id * head_num * size_per_head + head_id * size_per_head +
+      threadIdx.x] = src[batch_id * head_num * seq_len * size_per_head +
+                         head_id * seq_len * size_per_head +
+                         seq_id * size_per_head + threadIdx.x];
+}
+
 template <typename T>
 __global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
   // Input: BxSx3xNxH
@@ -209,6 +274,48 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
   return ret;
 }
 
+void QkvToContextPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *in,
+    int nb_inputs,
+    const nvinfer1::DynamicPluginTensorDesc *out,
+    int nb_outputs) TRT_NOEXCEPT {
+  auto input_dims = in[0].desc.dims;
+  int batch = input_dims.d[0];
+  int real_seq_len = input_dims.d[1];
+  int seq_len = round_up(real_seq_len, 8);
+  if (batch != -1 && real_seq_len != -1) {
+    int device_id = 0;
+    cudaGetDevice(&device_id);
+    auto *device_ctx = static_cast<phi::GPUContext *>(
+        platform::DeviceContextPool::Instance().Get(
+            platform::CUDAPlace(device_id)));
+    const phi::GPUContext &dev_ctx = *device_ctx;
+    auto stream = dev_ctx.stream();
+    tensor_.Resize({batch, seq_len, seq_len, head_number_});
+    int blocks = batch * head_number_ * seq_len;
+    if (in[0].desc.type == nvinfer1::DataType::kHALF) {
+      mask_half_ = reinterpret_cast<half *>(
+          tensor_.mutable_data<int16_t>(platform::CUDAPlace(device_id)));
+      reset_qk_bias<<<blocks, 1024, 0, stream>>>(
+          mask_half_, real_seq_len, seq_len);
+    } else if (in[0].desc.type == nvinfer1::DataType::kFLOAT) {
+      fake_qk_bias_ = reinterpret_cast<float *>(
+          tensor_.mutable_data<int32_t>(platform::CUDAPlace(device_id)));
+      long size = sizeof(int32_t) * batch * seq_len * seq_len * head_number_;
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipMemsetAsync(fake_qk_bias_, 0, size, dev_ctx.stream()));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(fake_qk_bias_, 0, size, dev_ctx.stream()));
+#endif
+    } else {
+      PADDLE_THROW(platform::errors::Fatal(
+          "The QKV TRT Plugin's input type should be float or half."));
+    }
+  }
+}
+
 bool QkvToContextPluginDynamic::supportsFormatCombination(
     int pos,
     const nvinfer1::PluginTensorDesc *in_out,
@@ -277,15 +384,6 @@ __global__ void apply_scale(T *data, T scale, int n) {
 #endif
 }
 
-inline int round_up(int seq_len, int multiple = 32) {
-  PADDLE_ENFORCE_GT(
-      multiple,
-      0,
-      platform::errors::InvalidArgument(
-          "multiple should be a positive number,but it's (%d)", multiple));
-  return ((seq_len + multiple - 1) / multiple) * multiple;
-}
-
 template <typename T>
 __global__ void broadcast(const T *src,
                           T *dst,
@@ -342,6 +440,10 @@ int QkvToContextPluginDynamic::enqueue(
                                           head_number_);
       qk_bias = temp_qk_bias;
     }
+    // fake qk_bias
+    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
+      qk_bias = fake_qk_bias_;
+    }
     const float *input1_data = static_cast<const float *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
     TransposeQKV(
@@ -373,6 +475,16 @@ int QkvToContextPluginDynamic::enqueue(
   } else if (input_type == nvinfer1::DataType::kHALF) {
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
     VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp16";
+    int real_seq_len = seq_len;
+    bool need_padding = false;
+    // fake qk_bias
+    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
+      seq_len = round_up(real_seq_len, 8);
+      scratch_size = batch * head_number_ * seq_len * seq_len * 1;
+      input_num = batch * seq_len * 3 * head_number_ * head_size_;
+      multihead_temp_tensor.Resize({scratch_size + input_num});
+      need_padding = (real_seq_len != seq_len);
+    }
     auto *multihead_temp_data =
         multihead_temp_tensor.mutable_data<int16_t>(  // NOLINT
             platform::CUDAPlace(device_id));
@@ -398,10 +510,27 @@ int QkvToContextPluginDynamic::enqueue(
                                          head_number_);
       qk_bias = temp_qk_bias;
     }
+    // padding: mask_half_ = [0,0,...-1e20f,-1e20f]
+    // no_padding: mask_half_ = [0,.....0,.........,0]
+    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
+      qk_bias = mask_half_;
+    }
     const half *input1_data = static_cast<const half *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
-    TransposeQKV(
-        batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
+    if (need_padding) {
+      dim3 grid_p(seq_len, batch, 3);
+      dim3 block_p(head_size_, head_number_, 1);
+      transpose_qkv_padding<<<grid_p, block_p, 0, stream>>>(input0_data,
+                                                            tptr,
+                                                            batch,
+                                                            seq_len,
+                                                            head_number_,
+                                                            head_size_,
+                                                            real_seq_len);
+    } else {
+      TransposeQKV(
+          batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
+    }
 
     auto *device_ctx = static_cast<phi::GPUContext *>(
         platform::DeviceContextPool::Instance().Get(
@@ -430,8 +559,15 @@ int QkvToContextPluginDynamic::enqueue(
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;
     half *output = static_cast<half *>(outputs[0]);
-    transpose<half><<<grid, block, 0, stream>>>(
-        tptr, output, batch, seq_len, head_number_, head_size_);
+    if (need_padding) {
+      int grid_u = batch * head_number_ * real_seq_len;
+      int block_u = head_size_;
+      transpose_qkv_unpadding<half><<<grid_u, block_u, 0, stream>>>(
+          tptr, output, batch, seq_len, head_number_, head_size_, real_seq_len);
+    } else {
+      transpose<half><<<grid, block, 0, stream>>>(
+          tptr, output, batch, seq_len, head_number_, head_size_);
+    }
 #else
     PADDLE_THROW(platform::errors::Fatal(
         "The Ernie(Bert) TensorRT Plugin should be "
diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h
index 650681a7de8..17c9e904d42 100644
--- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h
@@ -97,7 +97,7 @@ class QkvToContextPluginDynamic : public DynamicPluginTensorRT {
   void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                        int nb_inputs,
                        const nvinfer1::DynamicPluginTensorDesc* out,
-                       int nb_outputs) TRT_NOEXCEPT override {}
+                       int nb_outputs) TRT_NOEXCEPT override;
 
   size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                           int nb_inputs,
@@ -124,6 +124,9 @@ class QkvToContextPluginDynamic : public DynamicPluginTensorRT {
   int head_number_;
   int head_size_;
   float scale_;
+  framework::Tensor tensor_;
+  half* mask_half_ = nullptr;
+  float* fake_qk_bias_ = nullptr;
 };
 
 class QkvToContextPluginDynamicCreator : public nvinfer1::IPluginCreator {
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
index cf8611e4d8b..d552692ae4f 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
@@ -424,6 +424,7 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
         # for static_shape
         clear_dynamic_shape()
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        self.trt_param.workspace_size = 2013265920
         yield self.create_inference_config(), (1, 4), (1e-5, 1e-5)
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), (1, 4), (1e-5, 1e-5)
@@ -431,6 +432,7 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
         # for dynamic_shape
         generate_dynamic_shape(attrs)
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        self.trt_param.workspace_size = 2013265920
         yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), (1, 3), (1e-5, 1e-5)
@@ -855,8 +857,8 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
 
     def sample_program_configs(self):
 
-        def generate_input1():
-            return np.zeros((1, 256, 768), dtype=np.float32)
+        def generate_input1(batch, length):
+            return np.zeros((batch, length, 768), dtype=np.float32)
 
         def generate_weight1():
             return np.random.rand(768, 2304).astype(np.float32)
@@ -864,210 +866,216 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
         def generate_weight2():
             return np.random.rand(2304).astype(np.float32)
 
-        ops_config = [{
-            "op_type": "matmul_v2",
-            "op_inputs": {
-                "X": ["input_data1"],
-                "Y": ["matmul1_weight"]
-            },
-            "op_outputs": {
-                "Out": ["matmul1_output"]
-            },
-            "op_attrs": {
-                "trans_x": False,
-                "trans_y": False
-            }
-        }, {
-            "op_type": "elementwise_add",
-            "op_inputs": {
-                "X": ["matmul1_output"],
-                "Y": ["elementwise_add1_weight"]
-            },
-            "op_outputs": {
-                "Out": ["elementwise_add1_output"]
-            },
-            "op_attrs": {
-                "Scale_out": 1.0,
-                "Scale_x": 1.0,
-                "Scale_y": 1.0,
-                "axis": 2
-            }
-        }, {
-            "op_type": "reshape2",
-            "op_inputs": {
-                "X": ["elementwise_add1_output"],
-            },
-            "op_outputs": {
-                "Out": ["reshape1_output"],
-                "XShape": ["reshape1_output_xshape"]
-            },
-            "op_attrs": {
-                "shape": [-1, 256, 3, 12, 64]
-            }
-        }, {
-            "op_type": "transpose2",
-            "op_inputs": {
-                "X": ["reshape1_output"]
-            },
-            "op_outputs": {
-                "Out": ["transpose1_output"],
-                "XShape": ["transpose1_output_xshape"]
-            },
-            "op_attrs": {
-                "axis": [2, 0, 3, 1, 4],
-                "data_format": "AnyLayout"
-            }
-        }, {
-            "op_type": "slice",
-            "op_inputs": {
-                "Input": ["transpose1_output"],
-            },
-            "op_outputs": {
-                "Out": ["slice1_output"]
-            },
-            "op_attrs": {
-                "axes": [0],
-                "starts": [0],
-                "ends": [1],
-                "decrease_axis": [0],
-                "infer_flags": [1]
-            }
-        }, {
-            "op_type": "slice",
-            "op_inputs": {
-                "Input": ["transpose1_output"],
-            },
-            "op_outputs": {
-                "Out": ["slice2_output"]
-            },
-            "op_attrs": {
-                "axes": [0],
-                "starts": [1],
-                "ends": [2],
-                "decrease_axis": [0],
-                "infer_flags": [1]
-            }
-        }, {
-            "op_type": "slice",
-            "op_inputs": {
-                "Input": ["transpose1_output"],
-            },
-            "op_outputs": {
-                "Out": ["slice3_output"]
-            },
-            "op_attrs": {
-                "axes": [0],
-                "starts": [2],
-                "ends": [3],
-                "decrease_axis": [0],
-                "infer_flags": [1]
-            }
-        }, {
-            "op_type": "transpose2",
-            "op_inputs": {
-                "X": ["slice2_output"]
-            },
-            "op_outputs": {
-                "Out": ["transpose2_output"],
-            },
-            "op_attrs": {
-                "axis": [0, 1, 3, 2],
-                "data_format": "AnyLayout"
-            }
-        }, {
-            "op_type": "matmul_v2",
-            "op_inputs": {
-                "X": ["slice1_output"],
-                "Y": ["transpose2_output"]
-            },
-            "op_outputs": {
-                "Out": ["matmul2_output"]
-            },
-            "op_attrs": {
-                "trans_x": False,
-                "trans_y": False
-            }
-        }, {
-            "op_type": "scale",
-            "op_inputs": {
-                "X": ["matmul2_output"],
-            },
-            "op_outputs": {
-                "Out": ["scale_output"]
-            },
-            "op_attrs": {
-                "scale": 0.125,
-                "bias": 0.0,
-                "bias_after_scale": True
-            }
-        }, {
-            "op_type": "softmax",
-            "op_inputs": {
-                "X": ["scale_output"]
-            },
-            "op_outputs": {
-                "Out": ["softmax_output"]
-            },
-            "op_attrs": {
-                "axis": -1,
-                "data_format": "AnyLayout"
-            }
-        }, {
-            "op_type": "matmul_v2",
-            "op_inputs": {
-                "X": ["softmax_output"],
-                "Y": ["slice3_output"]
-            },
-            "op_outputs": {
-                "Out": ["matmul3_output"]
-            },
-            "op_attrs": {
-                "trans_x": False,
-                "trans_y": False
-            }
-        }, {
-            "op_type": "transpose2",
-            "op_inputs": {
-                "X": ["matmul3_output"]
-            },
-            "op_outputs": {
-                "Out": ["transpose3_output"],
-                "XShape": ["transpose3_output_xshape"]
-            },
-            "op_attrs": {
-                "axis": [0, 2, 1, 3],
-                "data_format": "AnyLayout"
-            }
-        }, {
-            "op_type": "reshape2",
-            "op_inputs": {
-                "X": ["transpose3_output"]
-            },
-            "op_outputs": {
-                "Out": ["reshape2_output"],
-                "XShape": ["reshape2_output_xshape"]
-            },
-            "op_attrs": {
-                "shape": [-1, 256, 768]
-            }
-        }]
+        for batch in [2, 4]:
+            self.batch = batch
+            for length in [64, 384]:
+                self.length = length
+                ops_config = [{
+                    "op_type": "matmul_v2",
+                    "op_inputs": {
+                        "X": ["input_data1"],
+                        "Y": ["matmul1_weight"]
+                    },
+                    "op_outputs": {
+                        "Out": ["matmul1_output"]
+                    },
+                    "op_attrs": {
+                        "trans_x": False,
+                        "trans_y": False
+                    }
+                }, {
+                    "op_type": "elementwise_add",
+                    "op_inputs": {
+                        "X": ["matmul1_output"],
+                        "Y": ["elementwise_add1_weight"]
+                    },
+                    "op_outputs": {
+                        "Out": ["elementwise_add1_output"]
+                    },
+                    "op_attrs": {
+                        "Scale_out": 1.0,
+                        "Scale_x": 1.0,
+                        "Scale_y": 1.0,
+                        "axis": 2
+                    }
+                }, {
+                    "op_type": "reshape2",
+                    "op_inputs": {
+                        "X": ["elementwise_add1_output"],
+                    },
+                    "op_outputs": {
+                        "Out": ["reshape1_output"],
+                        "XShape": ["reshape1_output_xshape"]
+                    },
+                    "op_attrs": {
+                        "shape": [-1, self.length, 3, 12, 64]
+                    }
+                }, {
+                    "op_type": "transpose2",
+                    "op_inputs": {
+                        "X": ["reshape1_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["transpose1_output"],
+                        "XShape": ["transpose1_output_xshape"]
+                    },
+                    "op_attrs": {
+                        "axis": [2, 0, 3, 1, 4],
+                        "data_format": "AnyLayout"
+                    }
+                }, {
+                    "op_type": "slice",
+                    "op_inputs": {
+                        "Input": ["transpose1_output"],
+                    },
+                    "op_outputs": {
+                        "Out": ["slice1_output"]
+                    },
+                    "op_attrs": {
+                        "axes": [0],
+                        "starts": [0],
+                        "ends": [1],
+                        "decrease_axis": [0],
+                        "infer_flags": [1]
+                    }
+                }, {
+                    "op_type": "slice",
+                    "op_inputs": {
+                        "Input": ["transpose1_output"],
+                    },
+                    "op_outputs": {
+                        "Out": ["slice2_output"]
+                    },
+                    "op_attrs": {
+                        "axes": [0],
+                        "starts": [1],
+                        "ends": [2],
+                        "decrease_axis": [0],
+                        "infer_flags": [1]
+                    }
+                }, {
+                    "op_type": "slice",
+                    "op_inputs": {
+                        "Input": ["transpose1_output"],
+                    },
+                    "op_outputs": {
+                        "Out": ["slice3_output"]
+                    },
+                    "op_attrs": {
+                        "axes": [0],
+                        "starts": [2],
+                        "ends": [3],
+                        "decrease_axis": [0],
+                        "infer_flags": [1]
+                    }
+                }, {
+                    "op_type": "transpose2",
+                    "op_inputs": {
+                        "X": ["slice2_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["transpose2_output"],
+                    },
+                    "op_attrs": {
+                        "axis": [0, 1, 3, 2],
+                        "data_format": "AnyLayout"
+                    }
+                }, {
+                    "op_type": "matmul_v2",
+                    "op_inputs": {
+                        "X": ["slice1_output"],
+                        "Y": ["transpose2_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["matmul2_output"]
+                    },
+                    "op_attrs": {
+                        "trans_x": False,
+                        "trans_y": False
+                    }
+                }, {
+                    "op_type": "scale",
+                    "op_inputs": {
+                        "X": ["matmul2_output"],
+                    },
+                    "op_outputs": {
+                        "Out": ["scale_output"]
+                    },
+                    "op_attrs": {
+                        "scale": 0.125,
+                        "bias": 0.0,
+                        "bias_after_scale": True
+                    }
+                }, {
+                    "op_type": "softmax",
+                    "op_inputs": {
+                        "X": ["scale_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["softmax_output"]
+                    },
+                    "op_attrs": {
+                        "axis": -1,
+                        "data_format": "AnyLayout"
+                    }
+                }, {
+                    "op_type": "matmul_v2",
+                    "op_inputs": {
+                        "X": ["softmax_output"],
+                        "Y": ["slice3_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["matmul3_output"]
+                    },
+                    "op_attrs": {
+                        "trans_x": False,
+                        "trans_y": False
+                    }
+                }, {
+                    "op_type": "transpose2",
+                    "op_inputs": {
+                        "X": ["matmul3_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["transpose3_output"],
+                        "XShape": ["transpose3_output_xshape"]
+                    },
+                    "op_attrs": {
+                        "axis": [0, 2, 1, 3],
+                        "data_format": "AnyLayout"
+                    }
+                }, {
+                    "op_type": "reshape2",
+                    "op_inputs": {
+                        "X": ["transpose3_output"]
+                    },
+                    "op_outputs": {
+                        "Out": ["reshape2_output"],
+                        "XShape": ["reshape2_output_xshape"]
+                    },
+                    "op_attrs": {
+                        "shape": [-1, self.length, 768]
+                    }
+                }]
 
-        ops = self.generate_op_config(ops_config)
+                ops = self.generate_op_config(ops_config)
 
-        program_config = ProgramConfig(
-            ops=ops,
-            weights={
-                "matmul1_weight":
-                TensorConfig(data_gen=partial(generate_weight1)),
-                "elementwise_add1_weight":
-                TensorConfig(data_gen=partial(generate_weight2))
-            },
-            inputs={
-                "input_data1": TensorConfig(data_gen=partial(generate_input1))
-            },
-            outputs=["reshape2_output"])
+                program_config = ProgramConfig(
+                    ops=ops,
+                    weights={
+                        "matmul1_weight":
+                        TensorConfig(data_gen=partial(generate_weight1)),
+                        "elementwise_add1_weight":
+                        TensorConfig(data_gen=partial(generate_weight2))
+                    },
+                    inputs={
+                        "input_data1":
+                        TensorConfig(
+                            data_gen=partial(generate_input1, batch, length))
+                    },
+                    outputs=["reshape2_output"])
 
-        yield program_config
+                yield program_config
 
     def sample_predictor_configs(
             self, program_config) -> (paddle_infer.Config, List[int], float):
@@ -1105,6 +1113,9 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3,
                                                                          1e-3)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(), (1e-5,
+                                                                         1e-5)
 
     def add_skip_trt_case(self):
         pass
-- 
GitLab
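Note on the approach: the plugin rounds the real sequence length up to a multiple of 8 (so TensorRT's fp16 tensor-core paths see an aligned length) and neutralizes the padded tail by writing a large negative value into the fake qk bias, so softmax gives those positions effectively zero attention weight. Below is a minimal host-side C++ sketch of that masking idea, for illustration only: round_up mirrors the arithmetic of the helper added in qkv_to_context_plugin.cu (which the plugin calls with multiple = 8), while real_seq_len = 197 is a hypothetical ViT token count, not a value taken from the tests above.

    // sketch.cc -- illustrative only, not part of the patch.
    #include <cstdio>
    #include <vector>

    // Same arithmetic as the plugin's round_up(): the smallest multiple of
    // `multiple` that is >= seq_len.
    static int round_up(int seq_len, int multiple) {
      return ((seq_len + multiple - 1) / multiple) * multiple;
    }

    int main() {
      const int real_seq_len = 197;  // hypothetical ViT token count (assumed)
      const int seq_len = round_up(real_seq_len, 8);  // padded: 197 -> 200
      // One row of the fake qk bias: valid positions get 0, padded positions
      // get -1e20 so softmax drives their attention weight to ~0, which is
      // what reset_qk_bias() produces on the device.
      std::vector<float> qk_bias(seq_len);
      for (int i = 0; i < seq_len; ++i) {
        qk_bias[i] = (i >= real_seq_len) ? -1e20f : 0.0f;
      }
      std::printf("real=%d padded=%d bias[%d]=%g bias[%d]=%g\n", real_seq_len,
                  seq_len, real_seq_len - 1, qk_bias[real_seq_len - 1],
                  real_seq_len, qk_bias[real_seq_len]);
      return 0;
    }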