未验证 提交 4ed6eeab 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] add output(CLSInds) for fused_token_prune (#49271)

* add output(CLSInds) for fused_token_prune
上级 80d465ee
...@@ -140,15 +140,17 @@ __global__ void prune_token_keep_order(const T* tokens, ...@@ -140,15 +140,17 @@ __global__ void prune_token_keep_order(const T* tokens,
int32_t new_sequnce_length, int32_t new_sequnce_length,
const int32_t padding_token_length, const int32_t padding_token_length,
const int32_t* token_index, const int32_t* token_index,
T* output) { T* output0,
int32_t* output1) {
int batch = blockIdx.x; int batch = blockIdx.x;
int index = 0; int index = 0;
for (int i = 0; i < pre_sequnce_length; ++i) { for (int i = 0; i < pre_sequnce_length; ++i) {
if (token_index[batch * padding_token_length + i] < new_sequnce_length) { if (token_index[batch * padding_token_length + i] < new_sequnce_length) {
output[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x + output0[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x +
blockIdx.y * blockDim.x + threadIdx.x] = blockIdx.y * blockDim.x + threadIdx.x] =
tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x + tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x +
blockIdx.y * blockDim.x + threadIdx.x]; blockIdx.y * blockDim.x + threadIdx.x];
output1[batch * new_sequnce_length + index] = i;
index++; index++;
} }
} }
...@@ -273,7 +275,8 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination( ...@@ -273,7 +275,8 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc& prev = in_out[0]; const nvinfer1::PluginTensorDesc& prev = in_out[0];
return in.type == prev.type && in.format == prev.format; return in.type == prev.type && in.format == prev.format;
} else { } else {
return in.format == nvinfer1::TensorFormat::kLINEAR; return in.type == nvinfer1::DataType::kINT32 &&
in.format == nvinfer1::TensorFormat::kLINEAR;
} }
} }
} }
...@@ -457,6 +460,7 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -457,6 +460,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
const float* scores = static_cast<const float*>(inputs[0]); // reduce sum const float* scores = static_cast<const float*>(inputs[0]); // reduce sum
const float* tokens = static_cast<const float*>(inputs[1]); // X const float* tokens = static_cast<const float*>(inputs[1]); // X
float* output0 = static_cast<float*>(outputs[0]); float* output0 = static_cast<float*>(outputs[0]);
int32_t* output1 = static_cast<int32_t*>(outputs[1]);
int32_t padding_token_length; int32_t padding_token_length;
if (pre_sequnce_length <= 64) { if (pre_sequnce_length <= 64) {
padding_token_length = 64; padding_token_length = 64;
...@@ -533,7 +537,8 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -533,7 +537,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
new_sequnce_length, new_sequnce_length,
padding_token_length, padding_token_length,
token_index_, token_index_,
output0); output0,
output1);
} else { } else {
const dim3 num_blocks(B, pre_sequnce_length, length / num_threads); const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
prune_token_change_order<float> prune_token_change_order<float>
...@@ -548,6 +553,7 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -548,6 +553,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
const half* scores = static_cast<const half*>(inputs[0]); // reduce sum const half* scores = static_cast<const half*>(inputs[0]); // reduce sum
const half* tokens = static_cast<const half*>(inputs[1]); // X const half* tokens = static_cast<const half*>(inputs[1]); // X
half* output0 = static_cast<half*>(outputs[0]); half* output0 = static_cast<half*>(outputs[0]);
int32_t* output1 = static_cast<int32_t*>(outputs[1]);
int32_t padding_token_length; int32_t padding_token_length;
if (pre_sequnce_length <= 64) { if (pre_sequnce_length <= 64) {
padding_token_length = 64; padding_token_length = 64;
...@@ -624,7 +630,8 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -624,7 +630,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
new_sequnce_length, new_sequnce_length,
padding_token_length, padding_token_length,
token_index_, token_index_,
output0); output0,
output1);
} else { } else {
const dim3 num_blocks(B, pre_sequnce_length, length / num_threads); const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
prune_token_change_order<half> prune_token_change_order<half>
......
...@@ -525,6 +525,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) { ...@@ -525,6 +525,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
ASSERT_EQ(slimmed_x_v[6], 2); ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1); ASSERT_EQ(slimmed_x_v[7], 1);
ASSERT_EQ(cls_inds_v[0], 2);
ASSERT_EQ(cls_inds_v[1], 3);
ASSERT_EQ(cls_inds_v[2], 2);
ASSERT_EQ(cls_inds_v[3], 3);
ASSERT_EQ(cls_inds_v[4], 2);
ASSERT_EQ(cls_inds_v[5], 3);
ASSERT_EQ(cls_inds_v[6], 2);
ASSERT_EQ(cls_inds_v[7], 3);
LOG(INFO) << "finish"; LOG(INFO) << "finish";
#endif #endif
} }
...@@ -578,7 +587,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test { ...@@ -578,7 +587,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
std::map<std::string, std::vector<int>>(), std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(), std::map<std::string, std::vector<int>>(),
false, false,
phi::DataType::FLOAT16, phi::DataType::FLOAT32,
NaiveLogger::Global()); NaiveLogger::Global());
engine_->InitNetwork(); engine_->InitNetwork();
} }
...@@ -724,6 +733,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) { ...@@ -724,6 +733,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
ASSERT_EQ(slimmed_x_v[6], 2); ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1); ASSERT_EQ(slimmed_x_v[7], 1);
ASSERT_EQ(cls_inds_v[0], 2);
ASSERT_EQ(cls_inds_v[1], 3);
ASSERT_EQ(cls_inds_v[2], 2);
ASSERT_EQ(cls_inds_v[3], 3);
ASSERT_EQ(cls_inds_v[4], 2);
ASSERT_EQ(cls_inds_v[5], 3);
ASSERT_EQ(cls_inds_v[6], 2);
ASSERT_EQ(cls_inds_v[7], 3);
LOG(INFO) << "finish"; LOG(INFO) << "finish";
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册