Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4ed6eeab
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4ed6eeab
编写于
12月 23, 2022
作者:
W
Wangzheee
提交者:
GitHub
12月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference]add ouutput(CLSInds) for fused_token_prune (#49271)
* add ouutput(CLSInds) for fused_token_prune
上级
80d465ee
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
32 addition
and
7 deletion
+32
-7
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
.../inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
+13
-6
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
+19
-1
未找到文件。
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
浏览文件 @
4ed6eeab
...
...
@@ -140,15 +140,17 @@ __global__ void prune_token_keep_order(const T* tokens,
int32_t
new_sequnce_length
,
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
T
*
output
)
{
T
*
output0
,
int32_t
*
output1
)
{
int
batch
=
blockIdx
.
x
;
int
index
=
0
;
for
(
int
i
=
0
;
i
<
pre_sequnce_length
;
++
i
)
{
if
(
token_index
[
batch
*
padding_token_length
+
i
]
<
new_sequnce_length
)
{
output
[(
batch
*
new_sequnce_length
+
index
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
]
=
output
0
[(
batch
*
new_sequnce_length
+
index
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[(
batch
*
pre_sequnce_length
+
i
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
];
output1
[
batch
*
new_sequnce_length
+
index
]
=
i
;
index
++
;
}
}
...
...
@@ -273,7 +275,8 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
0
];
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
else
{
return
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
}
}
...
...
@@ -457,6 +460,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
const
float
*
scores
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
// reduce sum
const
float
*
tokens
=
static_cast
<
const
float
*>
(
inputs
[
1
]);
// X
float
*
output0
=
static_cast
<
float
*>
(
outputs
[
0
]);
int32_t
*
output1
=
static_cast
<
int32_t
*>
(
outputs
[
1
]);
int32_t
padding_token_length
;
if
(
pre_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
...
...
@@ -533,7 +537,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
output0
,
output1
);
}
else
{
const
dim3
num_blocks
(
B
,
pre_sequnce_length
,
length
/
num_threads
);
prune_token_change_order
<
float
>
...
...
@@ -548,6 +553,7 @@ int FusedTokenPrunePluginDynamic::enqueue(
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
// X
half
*
output0
=
static_cast
<
half
*>
(
outputs
[
0
]);
int32_t
*
output1
=
static_cast
<
int32_t
*>
(
outputs
[
1
]);
int32_t
padding_token_length
;
if
(
pre_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
...
...
@@ -624,7 +630,8 @@ int FusedTokenPrunePluginDynamic::enqueue(
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
output0
,
output1
);
}
else
{
const
dim3
num_blocks
(
B
,
pre_sequnce_length
,
length
/
num_threads
);
prune_token_change_order
<
half
>
...
...
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
浏览文件 @
4ed6eeab
...
...
@@ -525,6 +525,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
ASSERT_EQ
(
slimmed_x_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
7
],
1
);
ASSERT_EQ
(
cls_inds_v
[
0
],
2
);
ASSERT_EQ
(
cls_inds_v
[
1
],
3
);
ASSERT_EQ
(
cls_inds_v
[
2
],
2
);
ASSERT_EQ
(
cls_inds_v
[
3
],
3
);
ASSERT_EQ
(
cls_inds_v
[
4
],
2
);
ASSERT_EQ
(
cls_inds_v
[
5
],
3
);
ASSERT_EQ
(
cls_inds_v
[
6
],
2
);
ASSERT_EQ
(
cls_inds_v
[
7
],
3
);
LOG
(
INFO
)
<<
"finish"
;
#endif
}
...
...
@@ -578,7 +587,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
false
,
phi
::
DataType
::
FLOAT
16
,
phi
::
DataType
::
FLOAT
32
,
NaiveLogger
::
Global
());
engine_
->
InitNetwork
();
}
...
...
@@ -724,6 +733,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPruneHalf, test_fused_token_prune) {
ASSERT_EQ
(
slimmed_x_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
7
],
1
);
ASSERT_EQ
(
cls_inds_v
[
0
],
2
);
ASSERT_EQ
(
cls_inds_v
[
1
],
3
);
ASSERT_EQ
(
cls_inds_v
[
2
],
2
);
ASSERT_EQ
(
cls_inds_v
[
3
],
3
);
ASSERT_EQ
(
cls_inds_v
[
4
],
2
);
ASSERT_EQ
(
cls_inds_v
[
5
],
3
);
ASSERT_EQ
(
cls_inds_v
[
6
],
2
);
ASSERT_EQ
(
cls_inds_v
[
7
],
3
);
LOG
(
INFO
)
<<
"finish"
;
#endif
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录