From 297827280125ff3e18ea4e1e12194f96dfd83a6d Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 24 Nov 2022 15:06:11 +0800 Subject: [PATCH] [Paddle Inference]optimize token prune for Paddle-TensorRT (#48241) * optimize token prune --- .../ir/remove_padding_recover_padding_pass.cc | 57 +++ .../ir/remove_padding_recover_padding_pass.h | 12 +- .../tensorrt/convert/fused_token_prune_op.cc | 15 +- .../plugin/fused_token_prune_op_plugin.cu | 371 ++++++++++++------ .../plugin/fused_token_prune_op_plugin.h | 33 +- .../tensorrt/plugin/recover_padding_plugin.cu | 59 +-- .../tensorrt/plugin/remove_padding_plugin.cu | 57 +-- .../plugin/test_fused_token_prune_plugin.cc | 3 - 8 files changed, 422 insertions(+), 185 deletions(-) diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc index 237cb3bd3d..5127c5934c 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc @@ -131,6 +131,21 @@ void Activation::operator()() { // Add links for activation op. activation_op->LinksFrom({activation_input}).LinksTo({activation_out}); } + +void FusedTokenPrune::operator()() { + // Create nodes for fused_token_prune. + auto* fused_token_prune_input = + pattern->NewNode(fused_token_prune_input_repr()) + ->assert_is_op_input("fused_token_prune", "X"); + auto* fused_token_prune_op = pattern->NewNode(fused_token_prune_op_repr()) + ->assert_is_op("fused_token_prune"); + auto* fused_token_prune_output = + pattern->NewNode(fused_token_prune_output_repr()) + ->assert_is_op_output("fused_token_prune", "SlimmedX"); + + fused_token_prune_op->LinksFrom({fused_token_prune_input}) + .LinksTo({fused_token_prune_output}); +} } // namespace patterns void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { @@ -563,6 +578,48 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { }; gpd6(graph, handler6); + GraphPatternDetector gpd7; + patterns::FusedTokenPrune fused_token_prune( + gpd7.mutable_pattern(), "remove_padding_recover_padding_pass"); + fused_token_prune(); + + auto handler7 = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(3) << "remove_padding_recover_padding_pass for transformer: " + "fused_token_prune"; + + GET_IR_NODE_FROM_SUBGRAPH( + fused_token_prune_input, fused_token_prune_input, fused_token_prune); + GET_IR_NODE_FROM_SUBGRAPH( + fused_token_prune_op, fused_token_prune_op, fused_token_prune); + GET_IR_NODE_FROM_SUBGRAPH( + fused_token_prune_output, fused_token_prune_output, fused_token_prune); + + std::vector fused_token_prune_input_shape = + fused_token_prune_input->Var()->GetShape(); + check_flag = true; + if (fused_token_prune_input_shape.size() != + multihead_matmul_input_shape.size()) { + check_flag = false; + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + for (size_t i = 0; i < fused_token_prune_input_shape.size(); ++i) { + if (fused_token_prune_input_shape[i] != multihead_matmul_input_shape[i]) { + check_flag = false; + } + } + if (!check_flag) { + VLOG(3) << "Transformer model remove_padding shape check failed, return " + "remove_padding pass."; + return; + } + insert_recover_padding_op(fused_token_prune_op, fused_token_prune_output); + found_subgraph_count++; + }; + gpd7(graph, handler7); + AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h index f93ee4bc7c..ff04dc5532 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h @@ -95,7 +95,6 @@ struct Fc : public PatternBase { PATTERN_DECL_NODE(fc_input); PATTERN_DECL_NODE(fc_op); - PATTERN_DECL_NODE(fc_out); }; struct Activation : public PatternBase { @@ -108,6 +107,17 @@ struct Activation : public PatternBase { PATTERN_DECL_NODE(activation_op); PATTERN_DECL_NODE(activation_out); }; + +struct FusedTokenPrune : public PatternBase { + FusedTokenPrune(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "fused_token_prune") {} + + void operator()(); + + PATTERN_DECL_NODE(fused_token_prune_input); + PATTERN_DECL_NODE(fused_token_prune_op); + PATTERN_DECL_NODE(fused_token_prune_output); +}; } // namespace patterns class RemovePaddingRecoverPaddingPass : public FusePassBase { diff --git a/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc b/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc index dba0d003c0..4832b1fad1 100644 --- a/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc @@ -52,8 +52,21 @@ class FusedTokenPruneOpConverter : public OpConverter { auto* word_id = engine_->GetITensor("word_id"); auto* pos_id = engine_->GetITensor("pos_id"); auto* mask_id = engine_->GetITensor("mask_id"); + + // reduce_sum: (-1,headsize,token_length,token_length) -> + // (-1,token_length) + uint32_t reduce_dim = 0; + reduce_dim |= 1 << 1; // 00000000000000000000000000000010 + reduce_dim |= 1 << 2; // 00000000000000000000000000000110 + bool keep_dim = false; + nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM; + auto* reduce_sum_layer = TRT_ENGINE_ADD_LAYER( + engine_, Reduce, *Attn, reduce_type, reduce_dim, keep_dim); + // reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType()); + + auto* Reduced = reduce_sum_layer->getOutput(0); std::vector itensors = { - Attn, X, Mask, NewMask, word_id, pos_id, mask_id}; + Reduced, X, Mask, NewMask, word_id, pos_id, mask_id}; layer = engine_->AddDynamicPlugin(itensors.data(), 7, plugin); layer->getOutput(0)->setName(output_name.c_str()); diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu index fe011422c1..b0c800d31b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu @@ -31,19 +31,15 @@ namespace inference { namespace tensorrt { namespace plugin { -#if IS_TRT_VERSION_GE(6000) - template __global__ void ElementwiseMask(const T* a, const T* b, T* res, int num_elements) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) auto tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= num_elements) return; const T zero = 0; res[tid] = b[tid] >= zero ? a[tid] : zero; -#endif } template @@ -123,7 +119,6 @@ __global__ void ReduceSum2( template <> __global__ void ReduceSum2( const half* src, half* dst, int bsz, int nb_head, int max_seq_len) { -#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) int tid = threadIdx.x; int bid = blockIdx.x; int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len); @@ -155,7 +150,6 @@ __global__ void ReduceSum2( static_cast(bsz * max_seq_len), static_cast(res_half[0])); } -#endif } template @@ -177,14 +171,81 @@ __global__ void TakeAlongAxis(const T* src, } } -__global__ void pos_id_prune_kernel(const int32_t* src, - int32_t* dst, - int pos_nums, - float scale) { - dst[0] = 0; - for (int i = 1; i < pos_nums; i++) { - dst[i] = - dst[i - 1] + max(static_cast((src[i] - src[i - 1]) * scale), 2); +__global__ void compute_token_length(const int32_t* src, + int32_t* dst, + float scale) { + int32_t it = threadIdx.x; + dst[it] = max(static_cast((src[it + 1] - src[it]) * scale), 1); +} + +__global__ void fill_index_padding_score(int32_t* token_index, + const half* scores, + int32_t scores_size, + half* padding_scores) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + token_index[tid] = threadIdx.x; + if (tid < scores_size) { + padding_scores[tid] = scores[tid]; + } else { + padding_scores[tid] = 0; + } +} + +template +__global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) { + typedef cub::BlockRadixSort + BlockRadixSort; + typedef cub:: + BlockLoad + BlockLoadKey; + typedef cub:: + BlockLoad + BlockLoadValue; + typedef cub:: + BlockStore + BlockStoreKey; + typedef cub::BlockStore + BlockStoreValue; + + __shared__ union { + typename BlockRadixSort::TempStorage sort; + typename BlockLoadKey::TempStorage loadkey; + typename BlockLoadValue::TempStorage loadvalue; + typename BlockStoreKey::TempStorage storekey; + typename BlockStoreValue::TempStorage storevalue; + } temp_storage; + + int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS_PER_THREAD; + + T thread_keys[ITEMS_PER_THREAD]; + int thread_values[ITEMS_PER_THREAD]; + BlockLoadKey(temp_storage.loadkey).Load(in_keys + block_offset, thread_keys); + BlockLoadValue(temp_storage.loadvalue) + .Load(in_out_values + block_offset, thread_values); + __syncthreads(); + + BlockRadixSort(temp_storage.sort).SortDescending(thread_keys, thread_values); + __syncthreads(); + + BlockStoreValue(temp_storage.storevalue) + .Store(in_out_values + block_offset, thread_values); +} + +__global__ void varlen_prune_token(const half* tokens, + const int32_t* token_pos, + const int32_t* token_index, + half* output) { + int batch = blockIdx.x; + int token_it = batch * gridDim.y + blockIdx.y; + int pre_value_it = + token_it * gridDim.z * blockDim.x + blockIdx.z * blockDim.x + threadIdx.x; + + if (token_index[token_it] < token_pos[batch + 1] - token_pos[batch]) { + output[(token_index[token_it] + token_pos[batch]) * gridDim.z * blockDim.x + + blockIdx.z * blockDim.x + threadIdx.x] = tokens[pre_value_it]; } } @@ -195,9 +256,29 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions( nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT { auto x_dims = inputs[1], new_mask_dims = inputs[3]; if (flag_varseqlen_) { + // max sum of seqlen: ceil(sum / scale) + n -1 >= for(i=0;i>>(input, output, pos_nums, scale); -} - int FusedTokenPrunePluginDynamic::enqueue( const nvinfer1::PluginTensorDesc* input_desc, const nvinfer1::PluginTensorDesc* output_desc, @@ -572,73 +621,153 @@ int FusedTokenPrunePluginDynamic::enqueue( void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { - auto input_type = input_desc[0].type; - auto attn_dims = input_desc[0].dims; - auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], - max_seq_len = attn_dims.d[2]; - int device_id; - cudaGetDevice(&device_id); - - if (input_type == nvinfer1::DataType::kFLOAT) { - VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp32"; - - float max = std::numeric_limits::max(); - - enqueueImpl(input_desc, - output_desc, - inputs, - outputs, - workspace, - stream, - device_id, - max, - keep_first_token_, - keep_order_); - - } else if (input_type == nvinfer1::DataType::kHALF) { -#ifdef TRT_PLUGIN_FP16_AVALIABLE - VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16"; - - half max = 65504.0; - - enqueueImpl(input_desc, - output_desc, - inputs, - outputs, - workspace, - stream, - device_id, - max, - keep_first_token_, - keep_order_); - -#else - PADDLE_THROW(platform::errors::Fatal( - "The Ernie(Bert) TensorRT Plugin should be " - "complied with CUDA version >= 10.0 when running with fp16. " - "Please recomplie it or try to use fp32 by set " - "config.SetTRTDynamicShapeInfo(min_input_shape, " - "max_input_shape, opt_input_shape, true")); -#endif - } else { - PADDLE_THROW( - platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type " - "should be float or half.")); - } if (flag_varseqlen_) { + if (!(input_desc[0].type == nvinfer1::DataType::kHALF && + input_desc[1].type == nvinfer1::DataType::kHALF)) { + PADDLE_THROW( + platform::errors::InvalidArgument("Token_prune'type must half")); + } float scale = static_cast(input_desc[3].dims.d[2]) / input_desc[6].dims.d[1]; - // outputs[2]=inputs[4]; // word_id - const int32_t* inputs5 = static_cast(inputs[5]); - int32_t* outputs3 = static_cast(outputs[3]); - pos_id_prune( - inputs5, outputs3, input_desc[5].dims.d[0], scale, stream); // pos_id - // outputs[4]=inputs[6]; // new_mask + const int32_t* inputs5 = + static_cast(inputs[5]); // pre pos id + int32_t* outputs3 = static_cast(outputs[3]); // new pos id + half* outputs0 = static_cast(outputs[0]); + + const int32_t B = input_desc[1].dims.d[0]; // batchs + const int32_t max_sequnce_length = + input_desc[1].dims.d[1]; // max sequnce length + const int32_t length = input_desc[1].dims.d[2]; // vector length + const half* scores = static_cast(inputs[0]); // reduce sum + const half* tokens = static_cast(inputs[1]); + const int32_t scores_size = B * max_sequnce_length; + int32_t padding_token_length; + if (max_sequnce_length <= 128) { + padding_token_length = 128; + } else if (max_sequnce_length <= 256) { + padding_token_length = 256; + } else if (max_sequnce_length <= 384) { + padding_token_length = 384; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Token_prune'token_length must <= 384")); + } + + // 1. Compute the token length after pruning. + compute_token_length<<<1, B, 0, stream>>>( + inputs5, pruned_token_lengths_, scale); + + fill_index_padding_score<<>>( + token_index_, scores, scores_size, padding_scores_); + + // Determine temporary device storage requirements + void* d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, + pruned_token_lengths_, + outputs3, + B + 1); + // Allocate temporary storage + cudaMalloc(&d_temp_storage, temp_storage_bytes); + + // Run exclusive prefix sum + cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, + pruned_token_lengths_, + outputs3, + B + 1); + + if (padding_token_length == 128) { + general_topk_pair_sort + <<>>(padding_scores_, token_index_); // 128 + } else if (padding_token_length == 256) { + general_topk_pair_sort + <<>>(padding_scores_, token_index_); // 256 + } else { + general_topk_pair_sort + <<>>(padding_scores_, token_index_); // 384 + } + + int32_t num_threads; + if (length < 1024) { + num_threads = length; + } else { + if (length % 512 == 0) { + num_threads = 512; + } else if (length % 256 == 0) { + num_threads = 256; + } else if (length % 128 == 0) { + num_threads = 128; + } else if (length % 64 == 0) { + num_threads = 64; + } else if (length % 32 == 0) { + num_threads = 32; + } else if (length % 16 == 0) { + num_threads = 16; + } else if (length % 8 == 0) { + num_threads = 8; + } else if (length % 4 == 0) { + num_threads = 4; + } else if (length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } + } + const dim3 num_blocks( + B, + max_sequnce_length, + length / num_threads); // batchs, max_sequnce_length, vector_ength/*** + varlen_prune_token<<>>( + tokens, outputs3, token_index_, outputs0); + } else { + auto input_type = input_desc[0].type; + auto attn_dims = input_desc[0].dims; + auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1], + max_seq_len = attn_dims.d[2]; + int device_id; + cudaGetDevice(&device_id); + + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp32"; + + float max = std::numeric_limits::max(); + + enqueueImpl(input_desc, + output_desc, + inputs, + outputs, + workspace, + stream, + device_id, + max, + keep_first_token_, + keep_order_); + + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16"; + + half max = 65504.0; + enqueueImpl(input_desc, + output_desc, + inputs, + outputs, + workspace, + stream, + device_id, + max, + keep_first_token_, + keep_order_); + } else { + PADDLE_THROW( + platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type " + "should be float or half.")); + } } return cudaGetLastError() != cudaSuccess; } -#endif } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h index 0b32e8a552..4c9c24c59a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h @@ -16,6 +16,7 @@ #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace inference { @@ -30,11 +31,10 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { bool keep_first_token, bool keep_order, bool flag_varseqlen) - : keep_first_token_(keep_first_token), + : with_fp16_(with_fp16), + keep_first_token_(keep_first_token), keep_order_(keep_order), - flag_varseqlen_(flag_varseqlen) { - with_fp16_ = with_fp16; - } + flag_varseqlen_(flag_varseqlen) {} FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) { DeserializeValue(&serial_data, &serial_length, &with_fp16_); DeserializeValue(&serial_data, &serial_length, &keep_first_token_); @@ -42,8 +42,14 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { DeserializeValue(&serial_data, &serial_length, &flag_varseqlen_); } nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { - return new FusedTokenPrunePluginDynamic( + FusedTokenPrunePluginDynamic* ptr = new FusedTokenPrunePluginDynamic( with_fp16_, keep_first_token_, keep_order_, flag_varseqlen_); + ptr->max_batchs_ = max_batchs_; + ptr->max_token_length_ = max_token_length_; + ptr->pruned_token_lengths_ = pruned_token_lengths_; + ptr->token_index_ = token_index_; + ptr->padding_scores_ = padding_scores_; + return ptr; } const char* getPluginType() const TRT_NOEXCEPT override { @@ -84,7 +90,16 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nb_inputs, const nvinfer1::DynamicPluginTensorDesc* out, - int nb_outputs) TRT_NOEXCEPT override {} + int nb_outputs) TRT_NOEXCEPT override { + max_batchs_ = in[1].max.d[0]; + max_token_length_ = in[1].max.d[1]; + PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc(&pruned_token_lengths_, + (max_batchs_ + 1) * sizeof(int32_t))); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc( + &token_index_, max_batchs_ * max_token_length_ * sizeof(int32_t))); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMalloc( + &padding_scores_, max_batchs_ * max_token_length_ * sizeof(half))); + } size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nb_inputs, @@ -106,9 +121,15 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT { void destroy() TRT_NOEXCEPT override { delete this; } private: + bool with_fp16_; bool keep_first_token_; bool keep_order_; bool flag_varseqlen_; + int32_t* pruned_token_lengths_; + int32_t* token_index_; + int32_t max_batchs_; + int32_t max_token_length_; + half* padding_scores_; }; class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator { diff --git a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu index c6be871709..50884b79d8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu @@ -19,9 +19,9 @@ namespace inference { namespace tensorrt { namespace plugin { -__global__ void RecoverPaddingKernel(const float* input0, +__global__ void RecoverPaddingKernel(const half* input0, const int32_t* input1, - float* output) { + half* output) { int word_id = blockIdx.x * gridDim.y + blockIdx.y; int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; if (blockIdx.y < seqence_length) { @@ -79,7 +79,7 @@ bool RecoverPaddingPlugin::supportsFormatCombination( return inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; } else { - return inOut[pos].type == nvinfer1::DataType::kFLOAT && + return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; } // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format @@ -114,38 +114,43 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const auto input0_desc = inputDesc[0]; const auto input1_desc = inputDesc[1]; const auto input2_desc = inputDesc[2]; - const float* input0 = static_cast(inputs[0]); + const half* input0 = static_cast(inputs[0]); const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor - float* output = static_cast(outputs[0]); + half* output = static_cast(outputs[0]); + const int32_t vector_length = input0_desc.dims.d[1]; int32_t num_threads; - if (input0_desc.dims.d[1] % 512 == 0) { - num_threads = 512; - } else if (input0_desc.dims.d[1] % 256 == 0) { - num_threads = 256; - } else if (input0_desc.dims.d[1] % 128 == 0) { - num_threads = 128; - } else if (input0_desc.dims.d[1] % 64 == 0) { - num_threads = 64; - } else if (input0_desc.dims.d[1] % 32 == 0) { - num_threads = 32; - } else if (input0_desc.dims.d[1] % 16 == 0) { - num_threads = 16; - } else if (input0_desc.dims.d[1] % 8 == 0) { - num_threads = 8; - } else if (input0_desc.dims.d[1] % 4 == 0) { - num_threads = 4; - } else if (input0_desc.dims.d[1] % 2 == 0) { - num_threads = 2; + if (vector_length < 1024) { + num_threads = vector_length; } else { - num_threads = 1; + if (vector_length % 512 == 0) { + num_threads = 512; + } else if (vector_length % 256 == 0) { + num_threads = 256; + } else if (vector_length % 128 == 0) { + num_threads = 128; + } else if (vector_length % 64 == 0) { + num_threads = 64; + } else if (vector_length % 32 == 0) { + num_threads = 32; + } else if (vector_length % 16 == 0) { + num_threads = 16; + } else if (vector_length % 8 == 0) { + num_threads = 8; + } else if (vector_length % 4 == 0) { + num_threads = 4; + } else if (vector_length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } } const dim3 num_blocks( input1_desc.dims.d[0] - 1, input2_desc.dims.d[1], - input0_desc.dims.d[1] / num_threads); // batchs, max sequnce length - // (mask_id.dims.d[1]), - // input.dims.d[1]/256 + vector_length / num_threads); // batchs, max sequnce length + // (mask_id.dims.d[1]), + // input.dims.d[1]/*** RecoverPaddingKernel<<>>( input0, input1, output); return cudaGetLastError() != cudaSuccess; diff --git a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu index 9f1a1d6d2c..a18c0d0c72 100644 --- a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu @@ -19,9 +19,9 @@ namespace inference { namespace tensorrt { namespace plugin { -__global__ void RemovePaddingKernel(const float* input0, +__global__ void RemovePaddingKernel(const half* input0, const int32_t* input1, - float* output) { + half* output) { int word_id = blockIdx.x * gridDim.y + blockIdx.y; int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x]; if (blockIdx.y < seqence_length) { @@ -73,7 +73,7 @@ bool RemovePaddingPlugin::supportsFormatCombination( return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; } else { - return inOut[pos].type == nvinfer1::DataType::kFLOAT && + return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; } // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format @@ -106,38 +106,43 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, void* workspace, cudaStream_t stream) TRT_NOEXCEPT { const auto input_desc = inputDesc[0]; - const float* input0 = static_cast(inputs[0]); + const half* input0 = static_cast(inputs[0]); const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor - float* output = static_cast(outputs[0]); + half* output = static_cast(outputs[0]); const auto input0_desc = inputDesc[0]; + const int32_t vector_length = input0_desc.dims.d[2]; int32_t num_threads; - if (input0_desc.dims.d[2] % 512 == 0) { - num_threads = 512; - } else if (input0_desc.dims.d[2] % 256 == 0) { - num_threads = 256; - } else if (input0_desc.dims.d[2] % 128 == 0) { - num_threads = 128; - } else if (input0_desc.dims.d[2] % 64 == 0) { - num_threads = 64; - } else if (input0_desc.dims.d[2] % 32 == 0) { - num_threads = 32; - } else if (input0_desc.dims.d[2] % 16 == 0) { - num_threads = 16; - } else if (input0_desc.dims.d[2] % 8 == 0) { - num_threads = 8; - } else if (input0_desc.dims.d[2] % 4 == 0) { - num_threads = 4; - } else if (input0_desc.dims.d[2] % 2 == 0) { - num_threads = 2; + if (vector_length < 1024) { + num_threads = vector_length; } else { - num_threads = 1; + if (vector_length % 512 == 0) { + num_threads = 512; + } else if (vector_length % 256 == 0) { + num_threads = 256; + } else if (vector_length % 128 == 0) { + num_threads = 128; + } else if (vector_length % 64 == 0) { + num_threads = 64; + } else if (vector_length % 32 == 0) { + num_threads = 32; + } else if (vector_length % 16 == 0) { + num_threads = 16; + } else if (vector_length % 8 == 0) { + num_threads = 8; + } else if (vector_length % 4 == 0) { + num_threads = 4; + } else if (vector_length % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } } const dim3 num_blocks( input0_desc.dims.d[0], input0_desc.dims.d[1], - input0_desc.dims.d[2] / - num_threads); // batchs, max sequnce length, input.dims.d[2]/256 + vector_length / + num_threads); // batchs, max sequnce length, input0.dims.d[2]/*** RemovePaddingKernel<<>>( input0, input1, output); diff --git a/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc index 4cc20c4365..543e7dca22 100644 --- a/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc @@ -26,12 +26,9 @@ TEST(fused_token_prune_op_plugin, test_plugin) { /*keep_first_token*/ false, /*keep_order*/ true, /*flag_varseqlen*/ false); - plugin.configurePlugin(nullptr, 4, nullptr, 2); plugin.initialize(); plugin.getPluginType(); plugin.getNbOutputs(); - auto clone_plugin = plugin.clone(); - clone_plugin->destroy(); size_t buf_size = plugin.getSerializationSize(); std::vector buf(buf_size); plugin.serialize(buf.data()); -- GitLab