Unverified commit 4ed6eeab, authored by Wangzheee, committed by GitHub

[Paddle Inference] add output(CLSInds) for fused_token_prune (#49271)

* add output(CLSInds) for fused_token_prune
Parent 80d465ee
@@ -140,15 +140,17 @@ __global__ void prune_token_keep_order(const T* tokens,
int32_t new_sequnce_length,
const int32_t padding_token_length,
const int32_t* token_index,
-T* output) {
+T* output0,
+int32_t* output1) {
int batch = blockIdx.x;
int index = 0;
for (int i = 0; i < pre_sequnce_length; ++i) {
if (token_index[batch * padding_token_length + i] < new_sequnce_length) {
-output[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x +
+output0[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x +
blockIdx.y * blockDim.x + threadIdx.x] =
tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x +
blockIdx.y * blockDim.x + threadIdx.x];
+output1[batch * new_sequnce_length + index] = i;
index++;
}
}
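The new output1 tensor above is the CLSInds output named in the commit title: alongside the pruned hidden states written to output0, the kernel now records the original position of every token it keeps. Below is a minimal host-side reference sketch of that logic, for illustration only; the function name PruneTokenKeepOrderRef, the flat [batch, seq, hidden] layout, and the explicit hidden parameter (standing in for gridDim.y * blockDim.x in the kernel's indexing) are assumptions, while the remaining parameter names, including the "sequnce" spelling, mirror the kernel's.

```cpp
// Hypothetical host-side reference of prune_token_keep_order (not part of the patch).
#include <cstdint>
#include <vector>

template <typename T>
void PruneTokenKeepOrderRef(const std::vector<T>& tokens,             // [B, pre_len, hidden]
                            const std::vector<int32_t>& token_index,  // [B, padding_token_length]
                            int32_t batch_size,
                            int32_t pre_sequnce_length,
                            int32_t new_sequnce_length,
                            int32_t padding_token_length,
                            int32_t hidden,
                            std::vector<T>* output0,        // pruned X: [B, new_len, hidden]
                            std::vector<int32_t>* output1)  // CLSInds:  [B, new_len]
{
  output0->assign(static_cast<size_t>(batch_size) * new_sequnce_length * hidden, T(0));
  output1->assign(static_cast<size_t>(batch_size) * new_sequnce_length, 0);
  for (int32_t batch = 0; batch < batch_size; ++batch) {
    int32_t index = 0;
    for (int32_t i = 0; i < pre_sequnce_length; ++i) {
      // A token survives pruning when its rank is below the new sequence length.
      if (token_index[batch * padding_token_length + i] < new_sequnce_length) {
        for (int32_t h = 0; h < hidden; ++h) {
          (*output0)[(batch * new_sequnce_length + index) * hidden + h] =
              tokens[(batch * pre_sequnce_length + i) * hidden + h];
        }
        // New in this patch: remember the kept token's original position.
        (*output1)[batch * new_sequnce_length + index] = i;
        ++index;
      }
    }
  }
}
```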
@@ -273,7 +275,8 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const nvinfer1::PluginTensorDesc& prev = in_out[0];
return in.type == prev.type && in.format == prev.format;
} else {
-return in.format == nvinfer1::TensorFormat::kLINEAR;
+return in.type == nvinfer1::DataType::kINT32 &&
+       in.format == nvinfer1::TensorFormat::kLINEAR;
}
}
}
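The tightened check above means the extra tensor introduced by this patch (the index output) must be INT32 in linear format. For the engine to actually give it that type, the plugin also has to declare INT32 as the data type of output 1. A hedged sketch of what that declaration typically looks like for a TensorRT IPluginV2DynamicExt plugin follows; getOutputDataType is part of the TensorRT plugin interface, but this body is assumed for illustration and is not taken from the diff.

```cpp
// Assumed sketch (not from the patch): report the dtypes of the two outputs.
// Output 0 (the pruned X) follows the input dtype; output 1 (CLSInds) is INT32.
nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const noexcept {
  if (index == 1) {
    return nvinfer1::DataType::kINT32;  // the new CLSInds output
  }
  return input_types[1];  // pruned X keeps the dtype of input X
}
```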
@@ -457,6 +460,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
const float* scores = static_cast<const float*>(inputs[0]); // reduce sum
const float* tokens = static_cast<const float*>(inputs[1]); // X
float* output0 = static_cast<float*>(outputs[0]);
+int32_t* output1 = static_cast<int32_t*>(outputs[1]);
int32_t padding_token_length;
if (pre_sequnce_length <= 64) {
padding_token_length = 64;
@@ -533,7 +537,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
new_sequnce_length,
padding_token_length,
token_index_,
-output0);
+output0,
+output1);
} else {
const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
prune_token_change_order<float>
@@ -548,6 +553,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
const half* scores = static_cast<const half*>(inputs[0]); // reduce sum
const half* tokens = static_cast<const half*>(inputs[1]); // X
half* output0 = static_cast<half*>(outputs[0]);
+int32_t* output1 = static_cast<int32_t*>(outputs[1]);
int32_t padding_token_length;
if (pre_sequnce_length <= 64) {
padding_token_length = 64;
@@ -624,7 +630,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
new_sequnce_length,
padding_token_length,
token_index_,
-output0);
+output0,
+output1);
} else {
const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
prune_token_change_order<half>
......
@@ -525,6 +525,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1);
+ASSERT_EQ(cls_inds_v[0], 2);
+ASSERT_EQ(cls_inds_v[1], 3);
+ASSERT_EQ(cls_inds_v[2], 2);
+ASSERT_EQ(cls_inds_v[3], 3);
+ASSERT_EQ(cls_inds_v[4], 2);
+ASSERT_EQ(cls_inds_v[5], 3);
+ASSERT_EQ(cls_inds_v[6], 2);
+ASSERT_EQ(cls_inds_v[7], 3);
LOG(INFO) << "finish";
#endif
}
@@ -578,7 +587,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
std::map<std::string, std::vector<int>>(),
std::map<std::string, std::vector<int>>(),
false,
-phi::DataType::FLOAT16,
+phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
@@ -724,6 +733,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
ASSERT_EQ(slimmed_x_v[6], 2);
ASSERT_EQ(slimmed_x_v[7], 1);
+ASSERT_EQ(cls_inds_v[0], 2);
+ASSERT_EQ(cls_inds_v[1], 3);
+ASSERT_EQ(cls_inds_v[2], 2);
+ASSERT_EQ(cls_inds_v[3], 3);
+ASSERT_EQ(cls_inds_v[4], 2);
+ASSERT_EQ(cls_inds_v[5], 3);
+ASSERT_EQ(cls_inds_v[6], 2);
+ASSERT_EQ(cls_inds_v[7], 3);
LOG(INFO) << "finish";
#endif
}
......