From e5cf75d848f6b1d38cd6a7353c3f8598c6d3ef88 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 1 Dec 2022 19:57:37 +0800 Subject: [PATCH] [Paddle Inference] General optimization for no_varlen multihead (#48469) * general optimization for no_varlen multihead --- .../ir/remove_padding_recover_padding_pass.cc | 3 - ...t_embedding_eltwise_layernorm_fuse_pass.cc | 4 +- .../ir/trt_multihead_matmul_fuse_pass.cc | 4 +- .../ir/trt_skip_layernorm_fuse_pass.cc | 4 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 2 +- .../tensorrt/convert/multihead_matmul_op.cc | 123 +++++- .../convert/transformer_input_convert_op.cc | 2 +- .../inference/tensorrt/plugin/CMakeLists.txt | 2 +- .../tensorrt/plugin/remove_padding_plugin.cu | 1 - .../transformer_input_convert_plugin.cu | 122 ------ ...transformer_input_output_convert_plugin.cu | 356 ++++++++++++++++++ ...transformer_input_output_convert_plugin.h} | 117 +++++- 12 files changed, 595 insertions(+), 145 deletions(-) delete mode 100644 paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.cu rename paddle/fluid/inference/tensorrt/plugin/{transformer_input_convert_plugin.h => transformer_input_output_convert_plugin.h} (53%) diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc index 5127c5934c..19c2e0541b 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc @@ -439,9 +439,6 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { "remove_padding pass."; return; } - fc_op->Op()->RemoveAttr("in_num_col_dims"); - fc_op->Op()->SetAttr("in_num_col_dims", 1); - insert_remove_padding_op(fc_input, fc_op); insert_recover_padding_op(fc_op, fc_op->outputs[0]); found_subgraph_count++; diff --git a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc index f870796a4c..23ebbddf57 100644 --- a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc @@ -441,14 +441,14 @@ void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { std::string mask_id = Get("tensorrt_transformer_maskid"); if ((use_varseqlen && pos_id != "" && mask_id != "") || - (!use_varseqlen && pos_id == "" && mask_id == "")) { + (!use_varseqlen && pos_id == "")) { VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " "use_varseqlen, set pos_id, set " "mask_id. Or not use varseqlen, do not set " - "pos_id, set mask_id. Please " + "pos_id. Please " "reconfig")); } graph->Set(kEmbEltwiseLayernormPass, new bool(true)); diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc index 4ecc9919f5..1d17cba445 100644 --- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc @@ -1637,14 +1637,14 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const { "preln_embedding_eltwise_layernorm_fuse_" "pass. please use no_varseqlen")); } - } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + } else if (!use_varseqlen && pos_id == "") { VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " "use_varseqlen, set pos_id, set " "mask_id. Or not use varseqlen, do not set " - "pos_id, set mask_id. Please " + "pos_id. Please " "reconfig")); } graph->Set(kMultiheadMatmulPass, new bool(true)); diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index d33adab8b3..2e578a06e3 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -207,14 +207,14 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { "trt_embedding_eltwise_layernorm_fuse_pass, " "trt_multihead_matmul_fuse_pass. please use no_varseqlen")); } - } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + } else if (!use_varseqlen && pos_id == "") { VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " "use_varseqlen, set pos_id, set " "mask_id. Or not use varseqlen, do not set " - "pos_id, set mask_id. Please " + "pos_id. Please " "reconfig")); } } diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 63637c25be..38ed95ce33 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -332,7 +332,7 @@ class FcOpConverter : public OpConverter { } // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can // not add Shuffle layer in ernie's multihead. - if (x_dim.nbDims == 4 && x_num_col_dims == 1) { + if (x_dim.nbDims == 4 && x_dim.d[2] == 1 && x_dim.d[3] == 1) { if (enable_int8 || support_int8) { // add conv1x1 layer nvinfer1::DimsHW nv_ksize(1, 1); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index 0515cb513d..0a238eadd9 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h" namespace paddle { namespace inference { @@ -87,7 +88,7 @@ class MultiheadMatMulOpConverter : public OpConverter { engine_->tensorrt_transformer_posid() != "" && engine_->tensorrt_transformer_maskid() != ""; if (engine_->with_dynamic_shape()) { - if (flag_varseqlen) { + if (engine_->tensorrt_transformer_maskid() != "") { if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { PADDLE_THROW(platform::errors::Fatal( "use use_varseqlen must be int8 or half, not float32.")); @@ -98,8 +99,100 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, static_cast(bias_data), static_cast(bias_t->numel())}; - auto max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor"); - auto pos_id_tensor = engine_->GetITensor("pos_id"); + + nvinfer1::ITensor* mask_tensor; + nvinfer1::ITensor* pos_id_tensor; + nvinfer1::ITensor* max_seqlen_tensor; + auto* new_input = input; + if (flag_varseqlen) { + mask_tensor = engine_->GetITensor("qkv_plugin_mask"); + pos_id_tensor = engine_->GetITensor("pos_id"); + max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor"); + } else { + auto* bias_qk_tensor = + engine_->GetITensor(op_desc.Input("BiasQK").front()); + auto bias_qk_dims = bias_qk_tensor->getDimensions(); + PADDLE_ENFORCE_EQ(bias_qk_dims.nbDims, + 4, + platform::errors::InvalidArgument( + "The rank of Multihead Matmul'BiasQK must be " + "4, but got rank is %d.", + bias_qk_dims.nbDims)); + + nvinfer1::Dims start_dims = bias_qk_dims; + start_dims.d[0] = 0; + start_dims.d[1] = 0; + start_dims.d[2] = 0; + start_dims.d[3] = 0; + nvinfer1::Dims size_dims = bias_qk_dims; + nvinfer1::Dims step_dims = bias_qk_dims; + step_dims.d[0] = 1; + step_dims.d[1] = 1; + step_dims.d[2] = 1; + step_dims.d[3] = 1; + auto* shape_tensor = Shape(bias_qk_tensor); + + // (b,n,m,m) -> (b,1,m,1) + std::vector size_vec_tensor; + size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 0)); + size_vec_tensor.push_back(Add1DConstantLayer(1)); + size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 2)); + size_vec_tensor.push_back(Add1DConstantLayer(1)); + auto* size_tensor = Concat(size_vec_tensor); + auto* slice_layer = TRT_ENGINE_ADD_LAYER(engine_, + Slice, + *bias_qk_tensor, + start_dims, + size_dims, + step_dims); + slice_layer->setInput(2, *size_tensor); + + // half -> bool + auto* cast_layer_0 = TRT_ENGINE_ADD_LAYER( + engine_, Identity, *slice_layer->getOutput(0)); + cast_layer_0->setOutputType(0, nvinfer1::DataType::kBOOL); + + // bool kNOT + auto* not_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Unary, + *cast_layer_0->getOutput(0), + nvinfer1::UnaryOperation::kNOT); + + // bool -> int32 + auto* cast_layer_1 = + TRT_ENGINE_ADD_LAYER(engine_, Identity, *not_layer->getOutput(0)); + cast_layer_1->setOutputType(0, nvinfer1::DataType::kINT32); + + // Calculate the number of 1 : (b,1,m,1) -> (b) + uint32_t reduce_dim_0 = 0; + reduce_dim_0 |= 1 << 1; // 00000000000000000000000000000010 + reduce_dim_0 |= 1 << 2; // 00000000000000000000000000000110 + reduce_dim_0 |= 1 << 3; // 00000000000000000000000000001110 + bool keep_dim = false; + nvinfer1::ReduceOperation reduce_type = + nvinfer1::ReduceOperation::kSUM; + auto* reduce_sum_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Reduce, + *cast_layer_1->getOutput(0), + reduce_type, + reduce_dim_0, + keep_dim); + std::vector inputs_transformer; + inputs_transformer.emplace_back(input); + inputs_transformer.emplace_back( + reduce_sum_layer->getOutput(0)); // (b,m) + plugin::TransformerInputConvertPlugin* plugin = + new plugin::TransformerInputConvertPlugin(); + nvinfer1::ILayer* transformer_input_layer = engine_->AddDynamicPlugin( + inputs_transformer.data(), inputs_transformer.size(), plugin); + + new_input = transformer_input_layer->getOutput(0); + mask_tensor = transformer_input_layer->getOutput(1); + pos_id_tensor = transformer_input_layer->getOutput(2); + max_seqlen_tensor = transformer_input_layer->getOutput(3); + } if (engine_->with_interleaved()) { VLOG(4) << "fused multihead_matmul op: use_varseqlen and " "with_interleaved"; @@ -111,7 +204,7 @@ class MultiheadMatMulOpConverter : public OpConverter { float dp_probs = 1.0 / 127.0; nvinfer1::DimsHW nv_ksize(1, 1); fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, Convolution, *input, n, nv_ksize, weight, bias); + engine_, Convolution, *new_input, n, nv_ksize, weight, bias); fc_layer->setName( ("Multihead: Convolution/FullyConnected: (Output: " + output_name + ")") @@ -220,10 +313,10 @@ class MultiheadMatMulOpConverter : public OpConverter { if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, Convolution, *input, n, nv_ksize, weight, bias); + engine_, Convolution, *new_input, n, nv_ksize, weight, bias); } else { fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, FullyConnected, *input, n, weight, bias); + engine_, FullyConnected, *new_input, n, weight, bias); } if (op_desc.HasAttr("fc_out_threshold")) { @@ -282,14 +375,28 @@ class MultiheadMatMulOpConverter : public OpConverter { std::vector plugin_inputs; plugin_inputs.emplace_back(fc_layer->getOutput(0)); - plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask")); + plugin_inputs.emplace_back(mask_tensor); plugin_inputs.emplace_back(pos_id_tensor); plugin_inputs.emplace_back( max_seqlen_tensor); // max_seqlen, eval_placeholder_3 auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); - layer = plugin_layer; + + // recover no_varlen output + if (!flag_varseqlen) { + std::vector output_transformer; + output_transformer.emplace_back(plugin_layer->getOutput(0)); + output_transformer.emplace_back(input); + output_transformer.emplace_back(pos_id_tensor); + plugin::TransformerOutputConvertPlugin* plugin = + new plugin::TransformerOutputConvertPlugin(); + nvinfer1::ILayer* transformer_output_layer = + engine_->AddDynamicPlugin(output_transformer.data(), + output_transformer.size(), + plugin); + layer = transformer_output_layer; + } } } else { if (input_dims.d[1] <= 384 && !bias_qk_attr && diff --git a/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc b/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc index a9b80f076a..37257b9564 100644 --- a/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/transformer_input_convert_op.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index a72880780d..2ecb8c8c71 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -26,7 +26,7 @@ list( deformable_conv_op_plugin.cu matmul_op_int8_plugin.cu multihead_matmul_roformer_plugin.cu - transformer_input_convert_plugin.cu + transformer_input_output_convert_plugin.cu remove_padding_plugin.cu recover_padding_plugin.cu c_allreduce_op_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu index a18c0d0c72..ec874f71ad 100644 --- a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu @@ -105,7 +105,6 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { - const auto input_desc = inputDesc[0]; const half* input0 = static_cast(inputs[0]); const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor diff --git a/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu deleted file mode 100644 index c9e4852584..0000000000 --- a/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.cu +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -__global__ void TransformerInputConvertKernel(const int64_t* input, - int32_t* output0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ int32_t shared_data; - if (threadIdx.x == static_cast(input[tid])) { - atomicAdd(&shared_data, 1); - } - output0[0] = 0; - output0[blockIdx.x + 1] = shared_data; - __syncthreads(); - for (int i = 0; i < blockDim.x; ++i) { - output0[i + 1] += output0[i]; - } -} - -nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType( - int index, - const nvinfer1::DataType* input_types, - int nb_inputs) const TRT_NOEXCEPT { - return nvinfer1::DataType::kINT32; -} - -nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions( - int outputIndex, - const nvinfer1::DimsExprs* inputs, - int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs output_dims{}; - output_dims.nbDims = 1; - if (outputIndex == 0) { // PosId - const auto* one = exprBuilder.constant(1); - output_dims.d[0] = exprBuilder.operation( - nvinfer1::DimensionOperation::kSUM, *inputs[0].d[0], *one); - } else { // MaxSeqlen - output_dims.d[0] = inputs[0].d[1]; - } - return output_dims; -} - -bool TransformerInputConvertPlugin::supportsFormatCombination( - int pos, - const nvinfer1::PluginTensorDesc* inOut, - int nbInputs, - int nbOutputs) TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ(nbInputs, - 1, - platform::errors::InvalidArgument("Must have 1 inputs, " - "but got %d input(s). ", - nbInputs)); - PADDLE_ENFORCE_EQ(nbOutputs, - getNbOutputs(), - platform::errors::InvalidArgument("Must have 2 output, " - "but got %d output(s). ", - nbOutputs)); - if (pos == 0) { // input - return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; - } else { // output0, output1 - return inOut[pos].type == nvinfer1::DataType::kINT32 && - inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; - } -} - -void TransformerInputConvertPlugin::configurePlugin( - const nvinfer1::DynamicPluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, - int nbOutputs) TRT_NOEXCEPT {} - -void TransformerInputConvertPlugin::attachToContext( - cudnnContext* cudnnContext, - cublasContext* cublasContext, - nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} - -void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {} - -void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {} - -int TransformerInputConvertPlugin::enqueue( - const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) TRT_NOEXCEPT { - const auto input_desc = inputDesc[0]; - const int64_t* input = static_cast(inputs[0]); - int32_t* output0 = static_cast(outputs[0]); // PosId - // int32_t* output1 = static_cast(outputs[1]); // MaxSeqlen - - const int32_t num_blocks = input_desc.dims.d[0]; // batchs - const int32_t num_threads = input_desc.dims.d[1]; // max sequnce length - - TransformerInputConvertKernel<<>>( - input, output0); - return cudaGetLastError() != cudaSuccess; -} - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.cu new file mode 100644 index 0000000000..39e2a0b422 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.cu @@ -0,0 +1,356 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h" +#include "cub/cub.cuh" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +__global__ void remove_padding_kernel(const half* input0, + const int32_t* input1, + half* output) { + int word_id = blockIdx.x * gridDim.y + blockIdx.y; + int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; + if (blockIdx.y < seqence_length) { + output[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x + + blockIdx.z * blockDim.x + threadIdx.x] = + input0[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + + threadIdx.x]; + } +} + +__global__ void recover_padding_kernel(const half* input0, + const int32_t* input1, + half* output) { + int word_id = blockIdx.x * gridDim.y + blockIdx.y; + int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; + if (blockIdx.y < seqence_length) { + output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + + threadIdx.x] = + input0[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x + + blockIdx.z * blockDim.x + threadIdx.x]; + } else { + output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + + threadIdx.x] = 0; + } +} + +nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType( + int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + if (index == 0) { // new input + return nvinfer1::DataType::kHALF; + } else if (index == 1) { // mask + return nvinfer1::DataType::kHALF; + } else if (index == 2) { // pos id + return nvinfer1::DataType::kINT32; + } else if (index == 3) { // max_seqlen_tensor + return nvinfer1::DataType::kHALF; + } +} + +nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + constexpr size_t threadsPerCta384 = 1 * 8 * 32; + constexpr size_t xmmasM384 = 24; + constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; + int32_t maskSize_ = packedMaskSize384; + auto maskSize = exprBuilder.constant(maskSize_); + auto fp16maskSize = exprBuilder.operation( + nvinfer1::DimensionOperation::kPROD, *maskSize, *exprBuilder.constant(2)); + + auto one = exprBuilder.constant(1); + auto B = inputs[0].d[0]; + auto MaxLength = inputs[0].d[1]; + auto Hidden = inputs[0].d[2]; + + nvinfer1::DimsExprs output_dims; + if (outputIndex == 0) { // new input + output_dims.nbDims = 4; + output_dims.d[0] = exprBuilder.operation( + nvinfer1::DimensionOperation::kPROD, *B, *MaxLength); + output_dims.d[1] = Hidden; + output_dims.d[2] = exprBuilder.constant(1); + output_dims.d[3] = exprBuilder.constant(1); + } else if (outputIndex == 1) { // mask + output_dims.nbDims = 2; + output_dims.d[0] = B; + output_dims.d[1] = fp16maskSize; + } else if (outputIndex == 2) { // pos id + output_dims.nbDims = 1; + output_dims.d[0] = + exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, *B, *one); + } else if (outputIndex == 3) { // max_seqlen_tensor + output_dims.nbDims = 1; + output_dims.d[0] = MaxLength; + } + return output_dims; +} + +bool TransformerInputConvertPlugin::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, + 2, + platform::errors::InvalidArgument( + "TransformerInputConvertPlugin must have 2 inputs, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, + 4, + platform::errors::InvalidArgument( + "TransformerInputConvertPlugin must have 4 outputs, " + "but got %d output(s). ", + nbOutputs)); + if (pos == 0) { // input + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } else if (pos == 1) { // reducesum_qk_bias + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kINT32; + } else if (pos == 2) { // new input + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } else if (pos == 3) { // mask + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } else if (pos == 4) { // pos id + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kINT32; + } else if (pos == 5) { // max_seqlen_tensor + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } +} + +void TransformerInputConvertPlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + +void TransformerInputConvertPlugin::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + +void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {} + +void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {} + +int TransformerInputConvertPlugin::enqueue( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + // input(no_varlen), reducesum_qk_bias, input(varlen), mask, pos_id, + // max_seqlen_tensor + const half* input0 = static_cast(inputs[0]); // input(no_varlen) + const int32_t* input1 = + static_cast(inputs[1]); // reducesum_qk_bias + half* output0 = static_cast(outputs[0]); // input(varlen) + int32_t* output2 = static_cast(outputs[2]); // pos_id + const auto input0_desc = inputDesc[0]; + const int32_t B = input0_desc.dims.d[0]; // batchs + const int32_t MaxLength = input0_desc.dims.d[1]; // max token length + const int32_t HiddenSize = input0_desc.dims.d[2]; // hidden size + + // Determine temporary device storage requirements + void* d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::ExclusiveSum( + d_temp_storage, temp_storage_bytes, input1, output2, B + 1); + // Allocate temporary storage + cudaMalloc(&d_temp_storage, temp_storage_bytes); + + // Run exclusive prefix sum + cub::DeviceScan::ExclusiveSum( + d_temp_storage, temp_storage_bytes, input1, output2, B + 1); + const int32_t vector_length = HiddenSize; + int32_t num_threads; + if (vector_length < 1024) { + num_threads = vector_length; + } else { + if (vector_length % 512 == 0) { + num_threads = 512; + } else if (vector_length % 256 == 0) { + num_threads = 256; + } else if (vector_length % 128 == 0) { + num_threads = 128; + } else if (vector_length % 64 == 0) { + num_threads = 64; + } else if (vector_length % 32 == 0) { + num_threads = 32; + } else if (vector_length % 16 == 0) { + num_threads = 16; + } else if (vector_length % 8 == 0) { + num_threads = 8; + } else if (vector_length % 4 == 0) { + num_threads = 4; + } else if (vector_length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } + } + const dim3 num_blocks( + B, + MaxLength, + vector_length / + num_threads); // batchs, max sequnce length, input0.dims.d[2]/* + remove_padding_kernel<<>>( + input0, output2, output0); // input(no_varlen), pos_id, input(varlen) + return cudaGetLastError() != cudaSuccess; +} + +nvinfer1::DataType TransformerOutputConvertPlugin::getOutputDataType( + int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + if (index == 0) { + return nvinfer1::DataType::kHALF; + } +} + +nvinfer1::DimsExprs TransformerOutputConvertPlugin::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs output_dims; + if (outputIndex == 0) { + output_dims = inputs[1]; + } + return output_dims; +} + +bool TransformerOutputConvertPlugin::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, + 3, + platform::errors::InvalidArgument( + "TransformerOutputConvertPlugin must have 3 inputs, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, + 1, + platform::errors::InvalidArgument( + "TransformerOutputConvertPlugin must have 1 output, " + "but got %d output(s). ", + nbOutputs)); + if (pos == 0) { // qkv plugin output(varlen) + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } else if (pos == 1) { // qkv plugin input(no_varlen) + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } else if (pos == 2) { // pos id + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kINT32; + } else if (pos == 3) { // qkv plugin output(no_varlen) + return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR && + inOut[pos].type == nvinfer1::DataType::kHALF; + } +} + +void TransformerOutputConvertPlugin::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + +void TransformerOutputConvertPlugin::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + +void TransformerOutputConvertPlugin::detachFromContext() TRT_NOEXCEPT {} + +void TransformerOutputConvertPlugin::terminate() TRT_NOEXCEPT {} + +int TransformerOutputConvertPlugin::enqueue( + const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const half* input0 = + static_cast(inputs[0]); // qkv plugin output(varlen) + const half* input1 = + static_cast(inputs[1]); // qkv plugin input(no_varlen) + const int32_t* input2 = static_cast(inputs[2]); // pos id + half* output = + static_cast(outputs[0]); // qkv plugin output(no_varlen) + const auto input1_desc = inputDesc[1]; + const int32_t B = input1_desc.dims.d[0]; // batchs + const int32_t MaxLength = input1_desc.dims.d[1]; // max token length + const int32_t HiddenSize = input1_desc.dims.d[2]; // hidden size + + const int32_t vector_length = HiddenSize; + int32_t num_threads; + if (vector_length < 1024) { + num_threads = vector_length; + } else { + if (vector_length % 512 == 0) { + num_threads = 512; + } else if (vector_length % 256 == 0) { + num_threads = 256; + } else if (vector_length % 128 == 0) { + num_threads = 128; + } else if (vector_length % 64 == 0) { + num_threads = 64; + } else if (vector_length % 32 == 0) { + num_threads = 32; + } else if (vector_length % 16 == 0) { + num_threads = 16; + } else if (vector_length % 8 == 0) { + num_threads = 8; + } else if (vector_length % 4 == 0) { + num_threads = 4; + } else if (vector_length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } + } + const dim3 num_blocks( + B, + MaxLength, + vector_length / num_threads); // batchs, max sequnce length + // (mask_id.dims.d[1]), + // input.dims.d[1]/* + recover_padding_kernel<<>>( + input0, input2, output); + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h b/paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h similarity index 53% rename from paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h rename to paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h index 43ca34c427..80a5bfa5b1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h @@ -40,14 +40,14 @@ class TransformerInputConvertPlugin : public DynamicPluginTensorRT { return "transformer_input_convert_plugin"; } - int getNbOutputs() const TRT_NOEXCEPT override { return 2; } + int getNbOutputs() const TRT_NOEXCEPT override { return 4; } int initialize() TRT_NOEXCEPT { return 0; } void terminate() TRT_NOEXCEPT; nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) + nvinfer1::IExprBuilder& exprBuilder) // NOLINT TRT_NOEXCEPT override; bool supportsFormatCombination(int pos, @@ -134,7 +134,120 @@ class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator { std::string plugin_name_; nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; }; + +class TransformerOutputConvertPlugin : public DynamicPluginTensorRT { + public: + TransformerOutputConvertPlugin() {} + + TransformerOutputConvertPlugin(void const* serial_data, + size_t serial_length) {} + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + TransformerOutputConvertPlugin* ptr = new TransformerOutputConvertPlugin(); + return ptr; + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "transformer_output_convert_plugin"; + } + + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + + int initialize() TRT_NOEXCEPT { return 0; } + void terminate() TRT_NOEXCEPT; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) // NOLINT + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + void attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) + TRT_NOEXCEPT override; + + void detachFromContext() TRT_NOEXCEPT override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + protected: + size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; } + + void serialize(void* buffer) const TRT_NOEXCEPT override {} +}; + +class TransformerOutputConvertPluginCreator : public nvinfer1::IPluginCreator { + public: + TransformerOutputConvertPluginCreator() {} + const char* getPluginName() const TRT_NOEXCEPT override { + return "transformer_output_convert_plugin"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* plugin_field) + TRT_NOEXCEPT override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + void const* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + TransformerOutputConvertPlugin* obj = + new TransformerOutputConvertPlugin(serial_data, serial_length); + obj->setPluginNamespace(name); + return obj; + } + + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; +}; + REGISTER_TRT_PLUGIN_V2(TransformerInputConvertPluginCreator); +REGISTER_TRT_PLUGIN_V2(TransformerOutputConvertPluginCreator); } // namespace plugin } // namespace tensorrt } // namespace inference -- GitLab