From 4ed6eeabb88c7301092d25b1dbc8d79defc1ce6a Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Fri, 23 Dec 2022 13:29:44 +0800 Subject: [PATCH] [Paddle Inference]add ouutput(CLSInds) for fused_token_prune (#49271) * add ouutput(CLSInds) for fused_token_prune --- .../plugin/fused_token_prune_op_plugin.cu | 19 ++++++++++++------ .../inference/tensorrt/test_dynamic_engine.cc | 20 ++++++++++++++++++- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu index f65d40a0ea4..ebc9f228026 100644 --- a/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu @@ -140,15 +140,17 @@ __global__ void prune_token_keep_order(const T* tokens, int32_t new_sequnce_length, const int32_t padding_token_length, const int32_t* token_index, - T* output) { + T* output0, + int32_t* output1) { int batch = blockIdx.x; int index = 0; for (int i = 0; i < pre_sequnce_length; ++i) { if (token_index[batch * padding_token_length + i] < new_sequnce_length) { - output[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x + - blockIdx.y * blockDim.x + threadIdx.x] = + output0[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x + + blockIdx.y * blockDim.x + threadIdx.x] = tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x + blockIdx.y * blockDim.x + threadIdx.x]; + output1[batch * new_sequnce_length + index] = i; index++; } } @@ -273,7 +275,8 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( const nvinfer1::PluginTensorDesc& prev = in_out[0]; return in.type == prev.type && in.format == prev.format; } else { - return in.format == nvinfer1::TensorFormat::kLINEAR; + return in.type == nvinfer1::DataType::kINT32 && + in.format == nvinfer1::TensorFormat::kLINEAR; } } } @@ -457,6 +460,7 @@ int FusedTokenPrunePluginDynamic::enqueue( const float* scores = static_cast(inputs[0]); // reduce sum const float* tokens = static_cast(inputs[1]); // X float* output0 = static_cast(outputs[0]); + int32_t* output1 = static_cast(outputs[1]); int32_t padding_token_length; if (pre_sequnce_length <= 64) { padding_token_length = 64; @@ -533,7 +537,8 @@ int FusedTokenPrunePluginDynamic::enqueue( new_sequnce_length, padding_token_length, token_index_, - output0); + output0, + output1); } else { const dim3 num_blocks(B, pre_sequnce_length, length / num_threads); prune_token_change_order @@ -548,6 +553,7 @@ int FusedTokenPrunePluginDynamic::enqueue( const half* scores = static_cast(inputs[0]); // reduce sum const half* tokens = static_cast(inputs[1]); // X half* output0 = static_cast(outputs[0]); + int32_t* output1 = static_cast(outputs[1]); int32_t padding_token_length; if (pre_sequnce_length <= 64) { padding_token_length = 64; @@ -624,7 +630,8 @@ int FusedTokenPrunePluginDynamic::enqueue( new_sequnce_length, padding_token_length, token_index_, - output0); + output0, + output1); } else { const dim3 num_blocks(B, pre_sequnce_length, length / num_threads); prune_token_change_order diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc index 46451c5db57..36d0f4b1d35 100644 --- a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -525,6 +525,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { ASSERT_EQ(slimmed_x_v[6], 2); ASSERT_EQ(slimmed_x_v[7], 1); + ASSERT_EQ(cls_inds_v[0], 2); + ASSERT_EQ(cls_inds_v[1], 3); + ASSERT_EQ(cls_inds_v[2], 2); + ASSERT_EQ(cls_inds_v[3], 3); + ASSERT_EQ(cls_inds_v[4], 2); + ASSERT_EQ(cls_inds_v[5], 3); + ASSERT_EQ(cls_inds_v[6], 2); + ASSERT_EQ(cls_inds_v[7], 3); + LOG(INFO) << "finish"; #endif } @@ -578,7 +587,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test { std::map>(), std::map>(), false, - phi::DataType::FLOAT16, + phi::DataType::FLOAT32, NaiveLogger::Global()); engine_->InitNetwork(); } @@ -724,6 +733,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) { ASSERT_EQ(slimmed_x_v[6], 2); ASSERT_EQ(slimmed_x_v[7], 1); + ASSERT_EQ(cls_inds_v[0], 2); + ASSERT_EQ(cls_inds_v[1], 3); + ASSERT_EQ(cls_inds_v[2], 2); + ASSERT_EQ(cls_inds_v[3], 3); + ASSERT_EQ(cls_inds_v[4], 2); + ASSERT_EQ(cls_inds_v[5], 3); + ASSERT_EQ(cls_inds_v[6], 2); + ASSERT_EQ(cls_inds_v[7], 3); + LOG(INFO) << "finish"; #endif } -- GitLab