BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Unverified commit 4ed6eeab, authored Dec 23, 2022 by Wangzheee, committed via GitHub on Dec 23, 2022.
[Paddle Inference] add output(CLSInds) for fused_token_prune (#49271)

* add output(CLSInds) for fused_token_prune
Parent: 80d465ee
Showing 2 changed files with 32 additions and 7 deletions (+32 -7)
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu  +13 -6
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc  +19 -1
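For orientation, the sketch below is a plain-C++ reference of what the patched kernel now produces: alongside the pruned token tensor (output0), it records, for every kept slot, that token's original position in the unpruned sequence (output1, exposed as the new CLSInds output). Tensor names and the row-major [B, N, H] layout are taken from the diff; the helper itself is illustrative and not part of the patch.

```cpp
#include <cstdint>
#include <vector>

// Host-side reference of prune_token_keep_order after this patch.
// tokens:       [B, pre_len, H], row-major
// token_index:  [B, padding_len]; a value < new_len marks a token that is kept
// output0:      [B, new_len, H]  pruned tokens, original order preserved
// output1:      [B, new_len]     original position of each kept token (CLSInds)
void PruneTokensKeepOrderRef(const std::vector<float>& tokens,
                             const std::vector<int32_t>& token_index,
                             int B, int pre_len, int new_len,
                             int padding_len, int H,
                             std::vector<float>* output0,
                             std::vector<int32_t>* output1) {
  output0->assign(static_cast<size_t>(B) * new_len * H, 0.f);
  output1->assign(static_cast<size_t>(B) * new_len, 0);
  for (int b = 0; b < B; ++b) {
    int index = 0;  // next free slot in the pruned sequence
    for (int i = 0; i < pre_len; ++i) {
      if (token_index[b * padding_len + i] < new_len) {
        for (int h = 0; h < H; ++h) {
          (*output0)[(b * new_len + index) * H + h] =
              tokens[(b * pre_len + i) * H + h];
        }
        (*output1)[b * new_len + index] = i;  // remember the original position
        ++index;
      }
    }
  }
}
```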
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
@@ -140,15 +140,17 @@ __global__ void prune_token_keep_order(const T* tokens,
                                         int32_t new_sequnce_length,
                                         const int32_t padding_token_length,
                                         const int32_t* token_index,
-                                        T* output) {
+                                        T* output0,
+                                        int32_t* output1) {
   int batch = blockIdx.x;
   int index = 0;
   for (int i = 0; i < pre_sequnce_length; ++i) {
     if (token_index[batch * padding_token_length + i] < new_sequnce_length) {
-      output[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x +
-             blockIdx.y * blockDim.x + threadIdx.x] =
+      output0[(batch * new_sequnce_length + index) * gridDim.y * blockDim.x +
+              blockIdx.y * blockDim.x + threadIdx.x] =
           tokens[(batch * pre_sequnce_length + i) * gridDim.y * blockDim.x +
                  blockIdx.y * blockDim.x + threadIdx.x];
+      output1[batch * new_sequnce_length + index] = i;
       index++;
     }
   }
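The kernel flattens a (batch, kept-slot, channel) coordinate into a single offset, with gridDim.y * blockDim.x together spanning the hidden dimension and blockIdx.y * blockDim.x + threadIdx.x selecting the channel. A small helper spelling out that arithmetic, inferred from the indexing above rather than taken from the patch:

```cpp
#include <cstdint>

// Offset arithmetic used by the kernel, for one element.
// hidden  = gridDim.y * blockDim.x
// channel = blockIdx.y * blockDim.x + threadIdx.x
__device__ __forceinline__ int64_t FlatOffset(int batch, int slot, int channel,
                                              int seq_len, int hidden) {
  // Row-major [B, seq_len, hidden]: ((batch * seq_len) + slot) * hidden + channel
  return (static_cast<int64_t>(batch) * seq_len + slot) * hidden + channel;
}
```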
@@ -273,7 +275,8 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
       const nvinfer1::PluginTensorDesc& prev = in_out[0];
       return in.type == prev.type && in.format == prev.format;
     } else {
-      return in.format == nvinfer1::TensorFormat::kLINEAR;
+      return in.type == nvinfer1::DataType::kINT32 &&
+             in.format == nvinfer1::TensorFormat::kLINEAR;
     }
   }
 }
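The else branch now also pins the data type of the new index output to INT32 instead of only checking the layout. In the usual TensorRT dynamic-plugin pattern, supportsFormatCombination inspects in_out[pos] and constrains type and format per binding; a generic sketch of that pattern follows (the binding order and function name are assumptions, not taken from the plugin):

```cpp
#include <NvInfer.h>

// Hedged sketch of the per-binding type/format check pattern used by
// IPluginV2DynamicExt plugins; the binding layout here is hypothetical.
bool SupportsFormatCombinationSketch(int pos,
                                     const nvinfer1::PluginTensorDesc* in_out,
                                     int nb_inputs,
                                     int nb_outputs) {
  const nvinfer1::PluginTensorDesc& desc = in_out[pos];
  if (desc.format != nvinfer1::TensorFormat::kLINEAR) return false;
  // Hypothetical layout: the last output carries int32 indices (CLSInds),
  // every other binding must match the type of the first input.
  if (pos == nb_inputs + nb_outputs - 1) {
    return desc.type == nvinfer1::DataType::kINT32;
  }
  return desc.type == in_out[0].type;
}
```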
@@ -457,6 +460,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
     const float* scores = static_cast<const float*>(inputs[0]);  // reduce sum
     const float* tokens = static_cast<const float*>(inputs[1]);  // X
     float* output0 = static_cast<float*>(outputs[0]);
+    int32_t* output1 = static_cast<int32_t*>(outputs[1]);
     int32_t padding_token_length;
     if (pre_sequnce_length <= 64) {
       padding_token_length = 64;
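Because enqueue now writes a second output, the caller must supply an extra device buffer of B * new_sequnce_length int32 elements in outputs[1]. A minimal allocation sketch using plain CUDA runtime calls; the function and variable names are illustrative only:

```cpp
#include <cstdint>
#include <cuda_runtime.h>

// Allocate the extra CLSInds buffer the plugin now expects in outputs[1].
// B and new_seq_len come from whatever pruned shape is in effect at runtime.
int32_t* AllocClsIndsBuffer(int B, int new_seq_len) {
  int32_t* cls_inds = nullptr;
  size_t bytes = static_cast<size_t>(B) * new_seq_len * sizeof(int32_t);
  cudaError_t err = cudaMalloc(&cls_inds, bytes);
  return err == cudaSuccess ? cls_inds : nullptr;  // caller frees with cudaFree
}
```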
@@ -533,7 +537,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
                                                      new_sequnce_length,
                                                      padding_token_length,
                                                      token_index_,
-                                                     output0);
+                                                     output0,
+                                                     output1);
     } else {
       const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
       prune_token_change_order<float>
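Put together, the keep-order launch now passes both buffers. The hunk only shows the trailing arguments, so the sketch below fills in the surrounding launch shape as assumptions (block size, a 2-D grid whose y-extent times blockDim.x spans the hidden size, and the wrapper itself are not from the patch):

```cpp
// Hedged reconstruction of the launch site; prune_token_keep_order is the
// kernel defined earlier in this file. B, length, num_threads and stream are
// assumptions about the surrounding code.
template <typename T>
void LaunchPruneKeepOrder(const T* tokens, int32_t pre_sequnce_length,
                          int32_t new_sequnce_length,
                          int32_t padding_token_length,
                          const int32_t* token_index_, T* output0,
                          int32_t* output1, int B, int length,
                          cudaStream_t stream) {
  const int num_threads = 256;                     // assumed block size
  const dim3 num_blocks(B, length / num_threads);  // y * blockDim.x spans hidden dim
  prune_token_keep_order<T>
      <<<num_blocks, num_threads, 0, stream>>>(tokens,
                                               pre_sequnce_length,
                                               new_sequnce_length,
                                               padding_token_length,
                                               token_index_,
                                               output0,
                                               output1);
}
```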
@@ -548,6 +553,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
     const half* scores = static_cast<const half*>(inputs[0]);  // reduce sum
     const half* tokens = static_cast<const half*>(inputs[1]);  // X
     half* output0 = static_cast<half*>(outputs[0]);
+    int32_t* output1 = static_cast<int32_t*>(outputs[1]);
     int32_t padding_token_length;
     if (pre_sequnce_length <= 64) {
       padding_token_length = 64;
@@ -624,7 +630,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
                                                      new_sequnce_length,
                                                      padding_token_length,
                                                      token_index_,
-                                                     output0);
+                                                     output0,
+                                                     output1);
     } else {
       const dim3 num_blocks(B, pre_sequnce_length, length / num_threads);
       prune_token_change_order<half>
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
@@ -525,6 +525,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
   ASSERT_EQ(slimmed_x_v[6], 2);
   ASSERT_EQ(slimmed_x_v[7], 1);
+  ASSERT_EQ(cls_inds_v[0], 2);
+  ASSERT_EQ(cls_inds_v[1], 3);
+  ASSERT_EQ(cls_inds_v[2], 2);
+  ASSERT_EQ(cls_inds_v[3], 3);
+  ASSERT_EQ(cls_inds_v[4], 2);
+  ASSERT_EQ(cls_inds_v[5], 3);
+  ASSERT_EQ(cls_inds_v[6], 2);
+  ASSERT_EQ(cls_inds_v[7], 3);
   LOG(INFO) << "finish";
 #endif
 }
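These expected values follow from the output1 bookkeeping added in the plugin: with new_sequnce_length = 2 kept tokens per sequence and, apparently, scores that rank positions 2 and 3 highest in every sequence (inferred from the assertions, not read from the test inputs), CLSInds is the flattened per-batch pattern 2, 3, 2, 3, ... A tiny check of that arithmetic:

```cpp
#include <cstdint>
#include <vector>

// Expected CLSInds layout checked above: row-major [B, new_len] holding the
// original position of each kept token. The kept positions {2, 3} per batch
// are an assumption inferred from the assertions.
std::vector<int32_t> ExpectedClsInds(int B, const std::vector<int32_t>& kept) {
  std::vector<int32_t> out;
  out.reserve(static_cast<size_t>(B) * kept.size());
  for (int b = 0; b < B; ++b) {
    out.insert(out.end(), kept.begin(), kept.end());  // same pattern per batch
  }
  return out;  // e.g. B = 4, kept = {2, 3}  ->  {2,3,2,3,2,3,2,3}
}
```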
@@ -578,7 +587,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
         std::map<std::string, std::vector<int>>(),
         std::map<std::string, std::vector<int>>(),
         false,
-        phi::DataType::FLOAT16,
+        phi::DataType::FLOAT32,
         NaiveLogger::Global());
     engine_->InitNetwork();
   }
@@ -724,6 +733,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
   ASSERT_EQ(slimmed_x_v[6], 2);
   ASSERT_EQ(slimmed_x_v[7], 1);
+  ASSERT_EQ(cls_inds_v[0], 2);
+  ASSERT_EQ(cls_inds_v[1], 3);
+  ASSERT_EQ(cls_inds_v[2], 2);
+  ASSERT_EQ(cls_inds_v[3], 3);
+  ASSERT_EQ(cls_inds_v[4], 2);
+  ASSERT_EQ(cls_inds_v[5], 3);
+  ASSERT_EQ(cls_inds_v[6], 2);
+  ASSERT_EQ(cls_inds_v[7], 3);
   LOG(INFO) << "finish";
 #endif
 }