Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
65c17315
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
65c17315
编写于
12月 21, 2022
作者:
W
Wangzheee
提交者:
GitHub
12月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference]optimize token prune for no varlen (#49094)
* optimize token prune for no varlen
上级
4cdeab7b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
546 addition
and
450 deletion
+546
-450
paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc
.../fluid/inference/tensorrt/convert/fused_token_prune_op.cc
+18
-15
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
.../inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
+281
-410
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
...d/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
+24
-3
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
+223
-22
未找到文件。
paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc
浏览文件 @
65c17315
...
@@ -38,6 +38,17 @@ class FusedTokenPruneOpConverter : public OpConverter {
...
@@ -38,6 +38,17 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto
output_name
=
op_desc
.
Output
(
"SlimmedX"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"SlimmedX"
)[
0
];
auto
out_inds_name
=
op_desc
.
Output
(
"CLSInds"
)[
0
];
auto
out_inds_name
=
op_desc
.
Output
(
"CLSInds"
)[
0
];
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t
reduce_dim
=
0
;
reduce_dim
|=
1
<<
1
;
// 00000000000000000000000000000010
reduce_dim
|=
1
<<
2
;
// 00000000000000000000000000000110
bool
keep_dim
=
false
;
nvinfer1
::
ReduceOperation
reduce_type
=
nvinfer1
::
ReduceOperation
::
kSUM
;
auto
*
reduce_sum_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
Attn
,
reduce_type
,
reduce_dim
,
keep_dim
);
auto
*
Reduced
=
reduce_sum_layer
->
getOutput
(
0
);
bool
with_fp16
=
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
...
@@ -53,21 +64,10 @@ class FusedTokenPruneOpConverter : public OpConverter {
...
@@ -53,21 +64,10 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto
*
pos_id
=
engine_
->
GetITensor
(
"pos_id"
);
auto
*
pos_id
=
engine_
->
GetITensor
(
"pos_id"
);
auto
*
mask_id
=
engine_
->
GetITensor
(
"mask_id"
);
auto
*
mask_id
=
engine_
->
GetITensor
(
"mask_id"
);
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t
reduce_dim
=
0
;
reduce_dim
|=
1
<<
1
;
// 00000000000000000000000000000010
reduce_dim
|=
1
<<
2
;
// 00000000000000000000000000000110
bool
keep_dim
=
false
;
nvinfer1
::
ReduceOperation
reduce_type
=
nvinfer1
::
ReduceOperation
::
kSUM
;
auto
*
reduce_sum_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
Attn
,
reduce_type
,
reduce_dim
,
keep_dim
);
// reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType());
auto
*
Reduced
=
reduce_sum_layer
->
getOutput
(
0
);
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Reduced
,
X
,
Mask
,
NewMask
,
word_id
,
pos_id
,
mask_id
};
Reduced
,
X
,
Mask
,
NewMask
,
word_id
,
pos_id
,
mask_id
};
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
7
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
itensors
.
size
(),
plugin
);
// inputs'number: 7
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
...
@@ -87,10 +87,13 @@ class FusedTokenPruneOpConverter : public OpConverter {
...
@@ -87,10 +87,13 @@ class FusedTokenPruneOpConverter : public OpConverter {
layer
->
getOutput
(
4
)
->
setName
(
"mask_id_after_token_prune"
);
layer
->
getOutput
(
4
)
->
setName
(
"mask_id_after_token_prune"
);
engine_
->
SetITensor
(
"mask_id"
,
layer
->
getOutput
(
4
));
engine_
->
SetITensor
(
"mask_id"
,
layer
->
getOutput
(
4
));
}
else
{
}
else
{
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Attn
,
X
,
Mask
,
NewMask
};
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Reduced
,
X
,
Mask
,
NewMask
};
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
4
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
itensors
.
size
(),
plugin
);
// inputs'number: 4
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
layer
->
getOutput
(
1
)
->
setName
(
out_inds_name
.
c_str
());
layer
->
getOutput
(
1
)
->
setName
(
out_inds_name
.
c_str
());
engine_
->
SetITensor
(
out_inds_name
,
layer
->
getOutput
(
1
));
engine_
->
SetITensor
(
out_inds_name
,
layer
->
getOutput
(
1
));
}
}
...
...
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
浏览文件 @
65c17315
...
@@ -31,150 +31,6 @@ namespace inference {
...
@@ -31,150 +31,6 @@ namespace inference {
namespace
tensorrt
{
namespace
tensorrt
{
namespace
plugin
{
namespace
plugin
{
template
<
typename
T
>
__global__
void
ElementwiseMask
(
const
T
*
a
,
const
T
*
b
,
T
*
res
,
int
num_elements
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_elements
)
return
;
const
T
zero
=
0
;
res
[
tid
]
=
b
[
tid
]
>=
zero
?
a
[
tid
]
:
zero
;
#endif
}
template
<
typename
T
>
__global__
void
FillZero
(
T
*
data
,
int
len
)
{
auto
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
len
)
return
;
const
T
zero
=
0
;
data
[
tid
]
=
zero
;
}
__global__
void
FillIndex
(
int32_t
*
indices
,
int
num_raws
,
int
num_cols
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_raws
*
num_cols
)
return
;
int
col
=
tid
%
num_cols
;
int
raw
=
tid
/
num_cols
;
indices
[
tid
]
=
col
;
}
template
<
typename
T
>
__global__
void
MaximumFirst
(
T
*
mat
,
int
num_raws
,
int
num_cols
,
T
max_value
)
{
auto
raw
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
raw
>=
num_raws
)
return
;
mat
[
raw
*
num_cols
]
=
max_value
;
}
__global__
void
FillOffsets
(
int
*
offsets
,
int
num_raws
,
int
num_cols
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>
num_raws
)
return
;
offsets
[
tid
]
=
tid
*
num_cols
;
}
template
<
typename
T
>
__global__
void
Slice
(
const
T
*
src
,
T
*
dst
,
int
num_raws
,
int
src_num_cols
,
int
dst_num_cols
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_raws
*
dst_num_cols
)
return
;
int
raw
=
tid
/
dst_num_cols
;
int
col
=
tid
%
dst_num_cols
;
dst
[
tid
]
=
src
[
raw
*
src_num_cols
+
col
];
}
template
<
typename
T
>
__global__
void
ReduceSum2
(
const
T
*
src
,
T
*
dst
,
int
bsz
,
int
nb_head
,
int
max_seq_len
)
{
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
num_blocks_per_head
=
((
max_seq_len
/
blockDim
.
x
)
*
max_seq_len
);
int
batch
=
bid
/
(
nb_head
*
num_blocks_per_head
);
int
col
=
bid
%
max_seq_len
;
int
head
=
(
bid
/
num_blocks_per_head
)
%
nb_head
;
extern
__shared__
T
res_float
[];
res_float
[
tid
]
=
src
[
batch
*
(
nb_head
*
max_seq_len
*
max_seq_len
)
+
head
*
(
max_seq_len
*
max_seq_len
)
+
col
+
tid
*
max_seq_len
];
__syncthreads
();
for
(
int
offset
=
blockDim
.
x
>>
1
;
offset
>
0
;
offset
>>=
1
)
{
if
(
tid
<
offset
)
{
res_float
[
tid
]
+=
res_float
[
tid
+
offset
];
}
__syncthreads
();
if
(
offset
%
2
==
1
&&
tid
==
offset
-
2
)
{
res_float
[
tid
]
+=
res_float
[
tid
+
1
];
}
}
if
(
tid
==
0
)
{
auto
*
dst_addr
=
dst
+
batch
*
max_seq_len
+
col
;
atomicAdd
(
dst_addr
,
res_float
[
0
]);
}
}
template
<
>
__global__
void
ReduceSum2
<
half
>
(
const
half
*
src
,
half
*
dst
,
int
bsz
,
int
nb_head
,
int
max_seq_len
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
num_blocks_per_head
=
((
max_seq_len
/
blockDim
.
x
)
*
max_seq_len
);
int
batch
=
bid
/
(
nb_head
*
num_blocks_per_head
);
int
col
=
bid
%
max_seq_len
;
int
head
=
(
bid
/
num_blocks_per_head
)
%
nb_head
;
extern
__shared__
half
res_half
[];
res_half
[
tid
]
=
src
[
batch
*
(
nb_head
*
max_seq_len
*
max_seq_len
)
+
head
*
(
max_seq_len
*
max_seq_len
)
+
col
+
tid
*
max_seq_len
];
__syncthreads
();
for
(
int
offset
=
blockDim
.
x
>>
1
;
offset
>
0
;
offset
>>=
1
)
{
if
(
tid
<
offset
)
{
res_half
[
tid
]
+=
res_half
[
tid
+
offset
];
}
__syncthreads
();
if
(
offset
%
2
==
1
&&
tid
==
offset
-
2
)
{
res_half
[
tid
]
+=
res_half
[
tid
+
1
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
phi
::
fastAtomicAdd
<
platform
::
float16
>
(
reinterpret_cast
<
platform
::
float16
*>
(
dst
),
static_cast
<
size_t
>
(
batch
*
max_seq_len
+
col
),
static_cast
<
size_t
>
(
bsz
*
max_seq_len
),
static_cast
<
platform
::
float16
>
(
res_half
[
0
]));
}
#endif
}
template
<
typename
T
>
__global__
void
TakeAlongAxis
(
const
T
*
src
,
T
*
dst
,
int32_t
*
indices
,
int
num_raws
,
int
src_num_cols
,
int
dst_num_cols
,
int
num_elements
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_raws
*
dst_num_cols
)
return
;
int
raw
=
tid
/
dst_num_cols
;
int
col
=
tid
%
dst_num_cols
;
for
(
int
i
=
0
;
i
<
num_elements
;
++
i
)
{
dst
[
tid
*
num_elements
+
i
]
=
*
(
src
+
(
raw
*
src_num_cols
+
indices
[
tid
])
*
num_elements
+
i
);
}
}
__global__
void
compute_token_length
(
const
int32_t
*
src
,
__global__
void
compute_token_length
(
const
int32_t
*
src
,
int32_t
*
dst
,
int32_t
*
dst
,
float
scale
)
{
float
scale
)
{
...
@@ -182,16 +38,18 @@ __global__ void compute_token_length(const int32_t* src,
...
@@ -182,16 +38,18 @@ __global__ void compute_token_length(const int32_t* src,
dst
[
it
]
=
max
(
static_cast
<
int
>
((
src
[
it
+
1
]
-
src
[
it
])
*
scale
),
1
);
dst
[
it
]
=
max
(
static_cast
<
int
>
((
src
[
it
+
1
]
-
src
[
it
])
*
scale
),
1
);
}
}
template
<
typename
T
>
__global__
void
fill_index_padding_score
(
int32_t
*
token_index
,
__global__
void
fill_index_padding_score
(
int32_t
*
token_index
,
const
half
*
scores
,
const
T
*
scores
,
int32_t
scores_size
,
int32_t
sequnce_length
,
half
*
padding_scores
)
{
T
*
padding_scores
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
padding_scores_it
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
token_index
[
tid
]
=
threadIdx
.
x
;
int
scores_it
=
threadIdx
.
x
+
blockIdx
.
x
*
sequnce_length
;
if
(
tid
<
scores_size
)
{
token_index
[
padding_scores_it
]
=
threadIdx
.
x
;
padding_scores
[
tid
]
=
scores
[
tid
];
if
(
threadIdx
.
x
<
sequnce_length
)
{
padding_scores
[
padding_scores_it
]
=
scores
[
scores_it
];
}
else
{
}
else
{
padding_scores
[
tid
]
=
0
;
padding_scores
[
padding_scores_it
]
=
0
;
}
}
}
}
...
@@ -238,21 +96,64 @@ __global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) {
...
@@ -238,21 +96,64 @@ __global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) {
.
Store
(
in_out_values
+
block_offset
,
thread_values
);
.
Store
(
in_out_values
+
block_offset
,
thread_values
);
}
}
__global__
void
varlen_prune_token
(
const
half
*
tokens
,
__global__
void
varlen_prune_token_change_order
(
const
int32_t
*
token_pos
,
const
half
*
tokens
,
const
int32_t
*
token_index
,
const
int32_t
*
token_pos
,
half
*
output
)
{
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
half
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
batch
=
blockIdx
.
x
;
int
token_it
=
batch
*
gridDim
.
y
+
blockIdx
.
y
;
int
token_it
=
batch
*
gridDim
.
y
+
blockIdx
.
y
;
int
pre_value_it
=
int
pre_value_it
=
token_it
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
;
token_it
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
;
int
token_index_it
=
batch
*
padding_token_length
+
blockIdx
.
y
;
if
(
token_index
[
token_it
]
<
token_pos
[
batch
+
1
]
-
token_pos
[
batch
])
{
if
(
token_index
[
token_index_it
]
<
token_pos
[
batch
+
1
]
-
token_pos
[
batch
])
{
output
[(
token_index
[
token_it
]
+
token_pos
[
batch
])
*
gridDim
.
z
*
blockDim
.
x
+
output
[(
token_index
[
token_index_it
]
+
token_pos
[
batch
])
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[
pre_value_it
];
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[
pre_value_it
];
}
}
}
}
template
<
typename
T
>
__global__
void
prune_token_change_order
(
const
T
*
tokens
,
int32_t
new_sequnce_length
,
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
T
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
token_it
=
batch
*
gridDim
.
y
+
blockIdx
.
y
;
int
pre_value_it
=
token_it
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
;
int
token_index_it
=
batch
*
padding_token_length
+
blockIdx
.
y
;
if
(
token_index
[
token_index_it
]
<
new_sequnce_length
)
{
output
[(
batch
*
new_sequnce_length
+
token_index
[
token_index_it
])
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[
pre_value_it
];
}
}
template
<
typename
T
>
__global__
void
prune_token_keep_order
(
const
T
*
tokens
,
int32_t
pre_sequnce_length
,
int32_t
new_sequnce_length
,
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
T
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
index
=
0
;
for
(
int
i
=
0
;
i
<
pre_sequnce_length
;
++
i
)
{
if
(
token_index
[
batch
*
padding_token_length
+
i
]
<
new_sequnce_length
)
{
output
[(
batch
*
new_sequnce_length
+
index
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[(
batch
*
pre_sequnce_length
+
i
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
];
index
++
;
}
}
}
nvinfer1
::
DimsExprs
FusedTokenPrunePluginDynamic
::
getOutputDimensions
(
nvinfer1
::
DimsExprs
FusedTokenPrunePluginDynamic
::
getOutputDimensions
(
int
output_index
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
const
nvinfer1
::
DimsExprs
*
inputs
,
...
@@ -353,7 +254,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
...
@@ -353,7 +254,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
"should be half for varseqlen."
));
"should be half for varseqlen."
));
}
}
}
else
if
(
pos
==
6
||
pos
==
11
)
{
// mask_id, mask_id_out
}
else
if
(
pos
==
6
||
pos
==
11
)
{
// mask_id, mask_id_out
return
(
in
.
type
==
nvinfer1
::
DataType
::
k
FLOAT
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
k
HALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
}
else
{
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
...
@@ -364,7 +265,6 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
...
@@ -364,7 +265,6 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
...
@@ -373,8 +273,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
...
@@ -373,8 +273,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
0
];
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
0
];
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
else
{
}
else
{
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
return
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
}
}
}
}
}
...
@@ -425,199 +324,6 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
...
@@ -425,199 +324,6 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
return
size
;
return
size
;
}
}
template
<
typename
T
>
inline
void
enqueueImpl
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace_ptr
,
cudaStream_t
stream
,
int
device_id
,
T
max_value
,
bool
keep_first_token_
,
bool
keep_order_
)
{
// Dims
auto
attn_dims
=
input_desc
[
0
].
dims
;
auto
x_dims
=
input_desc
[
1
].
dims
;
auto
new_mask_dims
=
input_desc
[
3
].
dims
;
auto
bsz
=
attn_dims
.
d
[
0
],
nb_head
=
attn_dims
.
d
[
1
],
max_seq_len
=
attn_dims
.
d
[
2
];
auto
c
=
x_dims
.
d
[
2
];
auto
slimmed_x_len
=
new_mask_dims
.
d
[
2
];
// Inputs
const
T
*
attn_data
=
static_cast
<
const
T
*>
(
inputs
[
0
]);
const
T
*
x_data
=
static_cast
<
const
T
*>
(
inputs
[
1
]);
const
T
*
mask_data
=
static_cast
<
const
T
*>
(
inputs
[
2
]);
// Outputs
T
*
output_data
=
static_cast
<
T
*>
(
outputs
[
0
]);
int32_t
*
output_indices_data
=
static_cast
<
int32_t
*>
(
outputs
[
1
]);
int
total
=
bsz
*
nb_head
*
max_seq_len
*
max_seq_len
;
int
block
=
operators
::
ComputeBlockSize
(
max_seq_len
);
int
grid
=
operators
::
CeilDivide
(
total
,
block
);
// Workspace for intermediate variable
char
*
workspace
=
static_cast
<
char
*>
(
workspace_ptr
);
T
*
attn_tmp_data
=
reinterpret_cast
<
T
*>
(
workspace
);
size_t
offset
=
total
*
sizeof
(
T
);
T
*
attn_accu_data
=
reinterpret_cast
<
T
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
T
);
int32_t
*
attn_accu_indices_data
=
reinterpret_cast
<
int32_t
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
int32_t
);
T
*
sort_attn_accu_data
=
reinterpret_cast
<
T
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
T
);
int32_t
*
sort_attn_accu_indices_data
=
reinterpret_cast
<
int32_t
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
int32_t
);
int
*
offsets_data
=
reinterpret_cast
<
int
*>
(
workspace
+
offset
);
offset
+=
(
bsz
+
1
)
*
sizeof
(
int
);
int32_t
*
slimmed_sort_attn_accu_indices_data
=
reinterpret_cast
<
int32_t
*>
(
workspace
+
offset
);
// 1. Filter attn by mask
ElementwiseMask
<
T
>
<<<
grid
,
block
,
0
,
stream
>>>
(
attn_data
,
mask_data
,
attn_tmp_data
,
total
);
total
=
bsz
*
max_seq_len
;
block
=
operators
::
ComputeBlockSize
(
max_seq_len
);
grid
=
operators
::
CeilDivide
(
total
,
block
);
FillZero
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
attn_accu_data
,
total
);
// 2. Reduce sum
total
=
bsz
*
nb_head
*
max_seq_len
*
max_seq_len
;
int
block_tmp
=
max_seq_len
;
while
(
block_tmp
>
1024
)
block_tmp
/=
2
;
// if max seq len > 1024, it must be 2^n
block
=
block_tmp
;
// make sure max_seq_len is an integral multiple of block_size
grid
=
operators
::
CeilDivide
(
total
,
block
);
ReduceSum2
<
T
><<<
grid
,
block
,
block
*
sizeof
(
T
),
stream
>>>
(
attn_tmp_data
,
attn_accu_data
,
bsz
,
nb_head
,
max_seq_len
);
// 3. Prepare token indices
total
=
bsz
*
max_seq_len
;
block
=
operators
::
ComputeBlockSize
(
max_seq_len
);
grid
=
operators
::
CeilDivide
(
total
,
block
);
FillIndex
<<<
grid
,
block
,
0
,
stream
>>>
(
attn_accu_indices_data
,
bsz
,
max_seq_len
);
// 4. Sort token indices by attn
if
(
keep_first_token_
)
{
MaximumFirst
<
T
>
<<<
bsz
,
1
,
0
,
stream
>>>
(
attn_accu_data
,
bsz
,
max_seq_len
,
max_value
);
}
size_t
temp_storage_bytes
=
-
1
;
int
num_items
=
bsz
*
max_seq_len
;
int
num_segments
=
bsz
;
FillOffsets
<<<
bsz
+
1
,
1
,
0
,
stream
>>>
(
offsets_data
,
bsz
,
max_seq_len
);
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortPairsDescending
(
nullptr
,
temp_storage_bytes
,
attn_accu_data
,
sort_attn_accu_data
,
attn_accu_indices_data
,
sort_attn_accu_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
T
)
*
8
,
stream
));
int64_t
temp_size
=
temp_storage_bytes
;
phi
::
DenseTensor
temp_storage
;
auto
*
temp_storage_data
=
temp_storage
.
mutable_data
<
uint8_t
>
(
{
temp_size
},
platform
::
CUDAPlace
(
device_id
));
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortPairsDescending
(
temp_storage_data
,
temp_storage_bytes
,
attn_accu_data
,
sort_attn_accu_data
,
attn_accu_indices_data
,
sort_attn_accu_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
T
)
*
8
,
stream
));
// 5. Slice
total
=
bsz
*
slimmed_x_len
;
block
=
operators
::
ComputeBlockSize
(
slimmed_x_len
);
grid
=
operators
::
CeilDivide
(
total
,
block
);
Slice
<
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
sort_attn_accu_indices_data
,
slimmed_sort_attn_accu_indices_data
,
bsz
,
max_seq_len
,
slimmed_x_len
);
if
(
keep_order_
)
{
// 6. reorder
num_items
=
bsz
*
slimmed_x_len
;
FillOffsets
<<<
bsz
+
1
,
1
,
0
,
stream
>>>
(
offsets_data
,
bsz
,
slimmed_x_len
);
temp_storage_bytes
=
-
1
;
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortKeys
(
nullptr
,
temp_storage_bytes
,
slimmed_sort_attn_accu_indices_data
,
output_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
int32_t
)
*
8
,
stream
));
temp_size
=
temp_storage_bytes
;
temp_storage
.
Resize
({
temp_size
});
temp_storage_data
=
temp_storage
.
mutable_data
<
uint8_t
>
(
platform
::
CUDAPlace
(
device_id
));
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortKeys
(
temp_storage_data
,
temp_storage_bytes
,
slimmed_sort_attn_accu_indices_data
,
output_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
int32_t
)
*
8
,
stream
));
TakeAlongAxis
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
x_data
,
output_data
,
output_indices_data
,
bsz
,
max_seq_len
,
slimmed_x_len
,
c
);
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemcpy
(
output_indices_data
,
slimmed_sort_attn_accu_indices_data
,
bsz
*
slimmed_x_len
*
sizeof
(
int32_t
),
cudaMemcpyDeviceToDevice
));
TakeAlongAxis
<
T
>
<<<
grid
,
block
,
0
,
stream
>>>
(
x_data
,
output_data
,
slimmed_sort_attn_accu_indices_data
,
bsz
,
max_seq_len
,
slimmed_x_len
,
c
);
}
}
int
FusedTokenPrunePluginDynamic
::
enqueue
(
int
FusedTokenPrunePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
...
@@ -628,49 +334,56 @@ int FusedTokenPrunePluginDynamic::enqueue(
...
@@ -628,49 +334,56 @@ int FusedTokenPrunePluginDynamic::enqueue(
if
(
flag_varseqlen_
)
{
if
(
flag_varseqlen_
)
{
if
(
!
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
&&
if
(
!
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
&&
input_desc
[
1
].
type
==
nvinfer1
::
DataType
::
kHALF
))
{
input_desc
[
1
].
type
==
nvinfer1
::
DataType
::
kHALF
))
{
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'type must half
"
));
"Token_prune'type must half for varseqlen
"
));
}
}
float
scale
=
float
scale
=
static_cast
<
float
>
(
input_desc
[
3
].
dims
.
d
[
2
])
/
input_desc
[
6
].
dims
.
d
[
1
];
static_cast
<
float
>
(
input_desc
[
3
].
dims
.
d
[
2
])
/
input_desc
[
2
].
dims
.
d
[
2
];
const
int32_t
*
inputs5
=
const
int32_t
*
input5
=
static_cast
<
const
int32_t
*>
(
inputs
[
5
]);
// pre pos id
static_cast
<
const
int32_t
*>
(
inputs
[
5
]);
// pre pos id
int32_t
*
outputs3
=
static_cast
<
int32_t
*>
(
outputs
[
3
]);
// new pos id
int32_t
*
output3
=
static_cast
<
int32_t
*>
(
outputs
[
3
]);
// new pos id
half
*
outputs0
=
static_cast
<
half
*>
(
outputs
[
0
]);
half
*
output0
=
static_cast
<
half
*>
(
outputs
[
0
]);
const
int32_t
B
=
input_desc
[
1
].
dims
.
d
[
0
];
// batchs
const
int32_t
B
=
input_desc
[
1
].
dims
.
d
[
0
];
// batchs
const
int32_t
max_sequnce_length
=
const
int32_t
max_sequnce_length
=
input_desc
[
1
].
dims
.
d
[
1
];
// max sequnce length
input_desc
[
1
].
dims
.
d
[
1
];
// max sequnce length
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
//
vector length
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
//
hidden size
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
const
int32_t
scores_size
=
B
*
max_sequnce_length
;
int32_t
padding_token_length
;
int32_t
padding_token_length
;
if
(
max_sequnce_length
<=
128
)
{
if
(
max_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
max_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
padding_token_length
=
128
;
}
else
if
(
max_sequnce_length
<=
256
)
{
}
else
if
(
max_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
padding_token_length
=
256
;
}
else
if
(
max_sequnce_length
<=
384
)
{
}
else
if
(
max_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
padding_token_length
=
384
;
}
else
if
(
max_sequnce_length
<=
512
)
{
padding_token_length
=
512
;
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <=
384
"
));
"Token_prune'token_length must <=
512
"
));
}
}
// 1. Compute the token length after pruning.
// 1. Compute the token length after pruning.
compute_token_length
<<<
1
,
B
,
0
,
stream
>>>
(
compute_token_length
<<<
1
,
B
,
0
,
stream
>>>
(
input
s
5
,
pruned_token_lengths_
,
scale
);
input5
,
pruned_token_lengths_
,
scale
);
fill_index_padding_score
<<<
B
,
padding_token_length
,
0
,
stream
>>>
(
// 2. Padding scores
token_index_
,
scores
,
scores_size
,
padding_scores_
);
fill_index_padding_score
<
half
><<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
max_sequnce_length
,
static_cast
<
half
*>
(
padding_scores_
));
// 3. compute new pos id
// Determine temporary device storage requirements
// Determine temporary device storage requirements
void
*
d_temp_storage
=
NULL
;
void
*
d_temp_storage
=
NULL
;
size_t
temp_storage_bytes
=
0
;
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
temp_storage_bytes
,
pruned_token_lengths_
,
pruned_token_lengths_
,
output
s
3
,
output3
,
B
+
1
);
B
+
1
);
// Allocate temporary storage
// Allocate temporary storage
cudaMalloc
(
&
d_temp_storage
,
temp_storage_bytes
);
cudaMalloc
(
&
d_temp_storage
,
temp_storage_bytes
);
...
@@ -679,20 +392,28 @@ int FusedTokenPrunePluginDynamic::enqueue(
...
@@ -679,20 +392,28 @@ int FusedTokenPrunePluginDynamic::enqueue(
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
temp_storage_bytes
,
pruned_token_lengths_
,
pruned_token_lengths_
,
output
s
3
,
output3
,
B
+
1
);
B
+
1
);
if
(
padding_token_length
==
128
)
{
// 4. sort scores
general_topk_pair_sort
<
half
,
32
,
4
>
if
(
padding_token_length
==
64
)
{
<<<
B
,
32
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 128
general_topk_pair_sort
<
half
,
32
,
2
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 64
}
else
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
half
,
32
,
4
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
half
,
64
,
4
>
general_topk_pair_sort
<
half
,
64
,
4
><<<
B
,
64
,
0
,
stream
>>>
(
<<<
B
,
64
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 256
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 256
}
else
if
(
padding_token_length
==
384
)
{
general_topk_pair_sort
<
half
,
96
,
4
><<<
B
,
96
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 384
}
else
{
}
else
{
general_topk_pair_sort
<
half
,
96
,
4
>
general_topk_pair_sort
<
half
,
128
,
4
><<<
B
,
128
,
0
,
stream
>>>
(
<<<
B
,
96
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 384
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 512
}
}
// 5. compute output
int32_t
num_threads
;
int32_t
num_threads
;
if
(
length
<
1024
)
{
if
(
length
<
1024
)
{
num_threads
=
length
;
num_threads
=
length
;
...
@@ -723,46 +444,196 @@ int FusedTokenPrunePluginDynamic::enqueue(
...
@@ -723,46 +444,196 @@ int FusedTokenPrunePluginDynamic::enqueue(
B
,
B
,
max_sequnce_length
,
max_sequnce_length
,
length
/
num_threads
);
// batchs, max_sequnce_length, vector_ength/***
length
/
num_threads
);
// batchs, max_sequnce_length, vector_ength/***
varlen_prune_token
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
varlen_prune_token
_change_order
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
output
s3
,
token_index_
,
outputs
0
);
tokens
,
output
3
,
padding_token_length
,
token_index_
,
output
0
);
}
else
{
}
else
{
auto
input_type
=
input_desc
[
0
].
type
;
auto
input_type
=
input_desc
[
0
].
type
;
auto
attn_dims
=
input_desc
[
0
].
dims
;
const
int32_t
B
=
input_desc
[
1
].
dims
.
d
[
0
];
// batchs
auto
bsz
=
attn_dims
.
d
[
0
],
nb_head
=
attn_dims
.
d
[
1
],
const
int32_t
pre_sequnce_length
=
input_desc
[
1
].
dims
.
d
[
1
];
max_seq_len
=
attn_dims
.
d
[
2
];
const
int32_t
new_sequnce_length
=
input_desc
[
3
].
dims
.
d
[
2
];
// new mask
int
device_id
;
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
// hidden size
cudaGetDevice
(
&
device_id
);
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp32"
;
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp32"
;
const
float
*
scores
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
// reduce sum
const
float
*
tokens
=
static_cast
<
const
float
*>
(
inputs
[
1
]);
// X
float
*
output0
=
static_cast
<
float
*>
(
outputs
[
0
]);
int32_t
padding_token_length
;
if
(
pre_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
pre_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
pre_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
pre_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
pre_sequnce_length
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <= 512"
));
}
float
max
=
std
::
numeric_limits
<
float
>::
max
();
// 1. Padding scores
fill_index_padding_score
<
float
><<<
B
,
padding_token_length
,
0
,
stream
>>>
(
enqueueImpl
<
float
>
(
input_desc
,
token_index_
,
output_desc
,
scores
,
inputs
,
pre_sequnce_length
,
outputs
,
static_cast
<
float
*>
(
padding_scores_
));
workspace
,
stream
,
// 2. sort scores
device_id
,
if
(
padding_token_length
==
64
)
{
max
,
general_topk_pair_sort
<
float
,
32
,
2
><<<
B
,
32
,
0
,
stream
>>>
(
keep_first_token_
,
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 64
keep_order_
);
}
else
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
float
,
32
,
4
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
float
,
64
,
4
><<<
B
,
64
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 256
}
else
if
(
padding_token_length
==
384
)
{
general_topk_pair_sort
<
float
,
96
,
4
><<<
B
,
96
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 384
}
else
{
general_topk_pair_sort
<
float
,
128
,
4
><<<
B
,
128
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 512
}
// 3. compute output
int32_t
num_threads
;
if
(
length
<
1024
)
{
num_threads
=
length
;
}
else
{
if
(
length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
if
(
keep_order_
)
{
const
dim3
num_blocks
(
B
,
length
/
num_threads
);
prune_token_keep_order
<
float
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
pre_sequnce_length
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
else
{
const
dim3
num_blocks
(
B
,
pre_sequnce_length
,
length
/
num_threads
);
prune_token_change_order
<
float
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp16"
;
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp16"
;
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
// X
half
*
output0
=
static_cast
<
half
*>
(
outputs
[
0
]);
int32_t
padding_token_length
;
if
(
pre_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
pre_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
pre_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
pre_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
pre_sequnce_length
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <= 512"
));
}
// 1. Padding scores
fill_index_padding_score
<
half
><<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
pre_sequnce_length
,
static_cast
<
half
*>
(
padding_scores_
));
// 2. sort scores
if
(
padding_token_length
==
64
)
{
general_topk_pair_sort
<
half
,
32
,
2
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 64
}
else
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
half
,
32
,
4
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
half
,
64
,
4
><<<
B
,
64
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 256
}
else
if
(
padding_token_length
==
384
)
{
general_topk_pair_sort
<
half
,
96
,
4
><<<
B
,
96
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 384
}
else
{
general_topk_pair_sort
<
half
,
128
,
4
><<<
B
,
128
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 512
}
half
max
=
65504.0
;
// 3. compute output
enqueueImpl
<
half
>
(
input_desc
,
int32_t
num_threads
;
output_desc
,
if
(
length
<
1024
)
{
inputs
,
num_threads
=
length
;
outputs
,
}
else
{
workspace
,
if
(
length
%
512
==
0
)
{
stream
,
num_threads
=
512
;
device_id
,
}
else
if
(
length
%
256
==
0
)
{
max
,
num_threads
=
256
;
keep_first_token_
,
}
else
if
(
length
%
128
==
0
)
{
keep_order_
);
num_threads
=
128
;
}
else
if
(
length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
if
(
keep_order_
)
{
const
dim3
num_blocks
(
B
,
length
/
num_threads
);
prune_token_keep_order
<
half
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
pre_sequnce_length
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
else
{
const
dim3
num_blocks
(
B
,
pre_sequnce_length
,
length
/
num_threads
);
prune_token_change_order
<
half
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The FusedTokenPrune TRT Plugin's input type "
platform
::
errors
::
Fatal
(
"The FusedTokenPrune TRT Plugin's input type "
...
...
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
浏览文件 @
65c17315
...
@@ -93,12 +93,33 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
...
@@ -93,12 +93,33 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
int
nb_outputs
)
TRT_NOEXCEPT
override
{
int
nb_outputs
)
TRT_NOEXCEPT
override
{
max_batchs_
=
in
[
1
].
max
.
d
[
0
];
max_batchs_
=
in
[
1
].
max
.
d
[
0
];
max_token_length_
=
in
[
1
].
max
.
d
[
1
];
max_token_length_
=
in
[
1
].
max
.
d
[
1
];
int32_t
padding_token_length
;
if
(
max_token_length_
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
max_token_length_
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
max_token_length_
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
max_token_length_
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
max_token_length_
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length(max) must <= 512"
));
}
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
pruned_token_lengths_
,
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
pruned_token_lengths_
,
(
max_batchs_
+
1
)
*
sizeof
(
int32_t
)));
(
max_batchs_
+
1
)
*
sizeof
(
int32_t
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
token_index_
,
max_batchs_
*
max_token_length_
*
sizeof
(
int32_t
)));
&
token_index_
,
max_batchs_
*
padding_token_length
*
sizeof
(
int32_t
)));
int32_t
type_size
=
4
;
if
(
in
[
0
].
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
)
{
type_size
=
2
;
}
else
{
type_size
=
4
;
}
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
padding_scores_
,
max_batchs_
*
max_token_length_
*
sizeof
(
half
)
));
&
padding_scores_
,
max_batchs_
*
padding_token_length
*
type_size
));
}
}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
...
@@ -129,7 +150,7 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
...
@@ -129,7 +150,7 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
int32_t
*
token_index_
;
int32_t
*
token_index_
;
int32_t
max_batchs_
;
int32_t
max_batchs_
;
int32_t
max_token_length_
;
int32_t
max_token_length_
;
half
*
padding_scores_
;
void
*
padding_scores_
;
};
};
class
FusedTokenPrunePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
class
FusedTokenPrunePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
...
...
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
浏览文件 @
65c17315
...
@@ -352,24 +352,24 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
...
@@ -352,24 +352,24 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
ctx_
->
PartialInitWithAllocator
();
ctx_
->
PartialInitWithAllocator
();
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape
=
{
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape
=
{
{
"attn"
,
{
4
,
1
,
4
,
4
}},
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"attn"
,
{
4
,
1
,
4
,
4
}},
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
=
{
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
=
{
{
"attn"
,
{
4
,
1
,
4
,
4
}},
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
engine_
=
new
TensorRTEngine
(
16
,
engine_
=
new
TensorRTEngine
(
16
,
1
<<
10
,
1
<<
10
,
AnalysisConfig
::
Precision
::
k
Half
,
AnalysisConfig
::
Precision
::
k
Float32
,
nullptr
,
nullptr
,
0
,
0
,
min_input_shape
,
min_input_shape
,
...
@@ -391,7 +391,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
...
@@ -391,7 +391,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
}
}
}
}
void
PrepareInputOutput
(
const
std
::
vector
<
std
::
vector
<
float
16
>>
inputs
,
void
PrepareInputOutput
(
const
std
::
vector
<
std
::
vector
<
float
>>
inputs
,
std
::
vector
<
std
::
vector
<
int
>>
output_shapes
)
{
std
::
vector
<
std
::
vector
<
int
>>
output_shapes
)
{
LOG
(
INFO
)
<<
"PrepareInputOutput"
;
LOG
(
INFO
)
<<
"PrepareInputOutput"
;
int
num_inputs
=
inputs
.
size
();
int
num_inputs
=
inputs
.
size
();
...
@@ -423,15 +423,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
...
@@ -423,15 +423,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
#if IS_TRT_VERSION_GE(8000)
tensorrt
::
plugin
::
TrtPluginRegistry
::
Global
()
->
RegistToTrt
();
tensorrt
::
plugin
::
TrtPluginRegistry
::
Global
()
->
RegistToTrt
();
auto
*
attn
=
engine_
->
DeclareInput
(
auto
*
attn
=
engine_
->
DeclareInput
(
"attn"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
"attn"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims2
{
-
1
,
4
});
auto
*
x
=
engine_
->
DeclareInput
(
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims3
{
-
1
,
4
,
1
});
"x"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims3
{
-
1
,
4
,
1
});
auto
*
mask
=
engine_
->
DeclareInput
(
auto
*
mask
=
engine_
->
DeclareInput
(
"mask"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
"mask"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
auto
*
new_mask
=
engine_
->
DeclareInput
(
auto
*
new_mask
=
engine_
->
DeclareInput
(
"new_mask"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
2
,
2
});
"new_mask"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims4
{
-
1
,
1
,
2
,
2
});
plugin
::
FusedTokenPrunePluginDynamic
*
plugin
=
plugin
::
FusedTokenPrunePluginDynamic
*
plugin
=
new
plugin
::
FusedTokenPrunePluginDynamic
(
tru
e
,
new
plugin
::
FusedTokenPrunePluginDynamic
(
/*with_fp16*/
fals
e
,
/*keep_first_token*/
false
,
/*keep_first_token*/
false
,
/*keep_order*/
true
,
/*keep_order*/
true
,
/*flag_varseqlen*/
false
);
/*flag_varseqlen*/
false
);
...
@@ -449,18 +449,215 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
...
@@ -449,18 +449,215 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
6
);
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
6
);
LOG
(
INFO
)
<<
"create input"
;
LOG
(
INFO
)
<<
"create input"
;
std
::
vector
<
float16
>
attn_v
(
64
);
std
::
vector
<
float
>
attn_v
(
16
);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
attn_v
[
j
*
4
+
k
]
=
k
;
}
}
std
::
vector
<
float
>
x_v
(
16
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
x_v
[
i
*
4
+
j
]
=
4
-
j
;
}
}
std
::
vector
<
float
>
mask_v
(
64
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
attn_v
[
i
*
16
+
j
*
4
+
k
]
=
k
;
mask_v
[
i
*
16
+
j
*
4
+
k
]
=
1
;
}
}
}
}
}
}
std
::
vector
<
float
>
new_mask_v
(
16
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
2
;
++
j
)
{
for
(
int
k
=
0
;
k
<
2
;
++
k
)
{
new_mask_v
[
i
*
4
+
j
*
2
+
k
]
=
1
;
}
}
}
LOG
(
INFO
)
<<
"create output"
;
std
::
vector
<
int
>
out_slimmed_x_shape
{
4
,
2
,
1
};
std
::
vector
<
int
>
out_cls_ins_shape
{
4
,
2
};
PrepareInputOutput
({
attn_v
,
x_v
,
mask_v
,
new_mask_v
},
{
out_slimmed_x_shape
,
out_cls_ins_shape
});
auto
*
attn_gpu_data
=
inputs_
[
0
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
x_gpu_data
=
inputs_
[
1
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
mask_gpu_data
=
inputs_
[
2
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
new_mask_gpu_data
=
inputs_
[
3
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
slimmed_x_gpu_data
=
outputs_
[
0
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
cls_inds_gpu_data
=
outputs_
[
1
].
mutable_data
<
int32_t
>
(
ctx_
->
GetPlace
());
LOG
(
INFO
)
<<
"create buffers"
;
std
::
vector
<
void
*>
buffers
(
6
);
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
attn_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
x_gpu_data
);
buffers
[
2
]
=
reinterpret_cast
<
void
*>
(
mask_gpu_data
);
buffers
[
3
]
=
reinterpret_cast
<
void
*>
(
new_mask_gpu_data
);
buffers
[
4
]
=
reinterpret_cast
<
void
*>
(
slimmed_x_gpu_data
);
buffers
[
5
]
=
reinterpret_cast
<
void
*>
(
cls_inds_gpu_data
);
LOG
(
INFO
)
<<
"Execute"
;
engine_
->
Execute
(
4
,
&
buffers
,
ctx_
->
stream
());
std
::
vector
<
float
>
slimmed_x_v
(
8
);
std
::
vector
<
int32_t
>
cls_inds_v
;
LOG
(
INFO
)
<<
"GetOutput"
;
GetOutput
(
slimmed_x_v
,
cls_inds_v
);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ
(
slimmed_x_v
[
0
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
1
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
2
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
3
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
4
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
5
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
7
],
1
);
LOG
(
INFO
)
<<
"finish"
;
#endif
}
class
TensorRTDynamicTestFusedTokenPruneHalf
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
ctx_
=
new
phi
::
GPUContext
(
platform
::
CUDAPlace
(
0
));
ctx_
->
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
platform
::
CUDAPlace
(
0
),
ctx_
->
stream
())
.
get
());
ctx_
->
SetHostAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
ctx_
->
SetZeroAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetZeroAllocator
(
platform
::
CUDAPlace
(
0
))
.
get
());
ctx_
->
SetPinnedAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CUDAPinnedPlace
())
.
get
());
ctx_
->
PartialInitWithAllocator
();
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape
=
{
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
=
{
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
engine_
=
new
TensorRTEngine
(
16
,
1
<<
10
,
AnalysisConfig
::
Precision
::
kHalf
,
nullptr
,
0
,
min_input_shape
,
max_input_shape
,
optim_input_shape
,
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
false
,
phi
::
DataType
::
FLOAT16
,
NaiveLogger
::
Global
());
engine_
->
InitNetwork
();
}
void
TearDown
()
override
{
if
(
engine_
)
{
delete
engine_
;
engine_
=
nullptr
;
}
}
void
PrepareInputOutput
(
const
std
::
vector
<
std
::
vector
<
float16
>>
inputs
,
std
::
vector
<
std
::
vector
<
int
>>
output_shapes
)
{
LOG
(
INFO
)
<<
"PrepareInputOutput"
;
int
num_inputs
=
inputs
.
size
();
int
num_outputs
=
output_shapes
.
size
();
inputs_
.
resize
(
num_inputs
);
outputs_
.
resize
(
num_outputs
);
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
paddle
::
framework
::
TensorFromVector
(
inputs
[
i
],
*
ctx_
,
&
inputs_
[
i
]);
}
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
outputs_
[
i
].
Resize
(
phi
::
make_ddim
(
output_shapes
[
i
]));
}
}
void
GetOutput
(
std
::
vector
<
float
>
&
slimmed_x
,
// NOLINT
std
::
vector
<
int32_t
>
&
cls_inds
)
{
// NOLINT
paddle
::
framework
::
TensorToVector
(
outputs_
[
0
],
*
ctx_
,
&
slimmed_x
);
paddle
::
framework
::
TensorToVector
(
outputs_
[
1
],
*
ctx_
,
&
cls_inds
);
}
protected:
std
::
vector
<
phi
::
DenseTensor
>
inputs_
;
std
::
vector
<
phi
::
DenseTensor
>
outputs_
;
TensorRTEngine
*
engine_
;
phi
::
GPUContext
*
ctx_
;
};
TEST_F
(
TensorRTDynamicTestFusedTokenPruneHalf
,
test_fused_token_prune
)
{
#if IS_TRT_VERSION_GE(8000)
tensorrt
::
plugin
::
TrtPluginRegistry
::
Global
()
->
RegistToTrt
();
auto
*
attn
=
engine_
->
DeclareInput
(
"attn"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims2
{
-
1
,
4
});
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims3
{
-
1
,
4
,
1
});
auto
*
mask
=
engine_
->
DeclareInput
(
"mask"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
auto
*
new_mask
=
engine_
->
DeclareInput
(
"new_mask"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
2
,
2
});
plugin
::
FusedTokenPrunePluginDynamic
*
plugin
=
new
plugin
::
FusedTokenPrunePluginDynamic
(
/*with_fp16*/
true
,
/*keep_first_token*/
false
,
/*keep_order*/
true
,
/*flag_varseqlen*/
false
);
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
attn
,
x
,
mask
,
new_mask
};
auto
*
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
4
,
plugin
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
platform
::
errors
::
InvalidArgument
(
"TRT fused_token_prune layer building failed."
));
std
::
vector
<
std
::
string
>
output_tensor_names
{
"out_slimmed_x"
,
"out_cls_inds"
};
for
(
size_t
i
=
0
;
i
<
2
;
i
++
)
{
layer
->
getOutput
(
i
)
->
setName
(
output_tensor_names
[
i
].
c_str
());
engine_
->
DeclareOutput
(
layer
,
i
,
output_tensor_names
[
i
]);
}
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
6
);
LOG
(
INFO
)
<<
"create input"
;
std
::
vector
<
float16
>
attn_v
(
16
);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
attn_v
[
j
*
4
+
k
]
=
k
;
}
}
std
::
vector
<
float16
>
x_v
(
16
);
std
::
vector
<
float16
>
x_v
(
16
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
x_v
[
i
*
4
+
j
]
=
1
;
x_v
[
i
*
4
+
j
]
=
4
-
j
;
}
}
}
}
std
::
vector
<
float16
>
mask_v
(
64
);
std
::
vector
<
float16
>
mask_v
(
64
);
...
@@ -509,20 +706,24 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
...
@@ -509,20 +706,24 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
engine_
->
Execute
(
4
,
&
buffers
,
ctx_
->
stream
());
engine_
->
Execute
(
4
,
&
buffers
,
ctx_
->
stream
());
std
::
vector
<
float
>
slimmed_x_v
;
std
::
vector
<
float
>
slimmed_x_v
(
8
)
;
std
::
vector
<
int32_t
>
cls_inds_v
;
std
::
vector
<
int32_t
>
cls_inds_v
;
LOG
(
INFO
)
<<
"GetOutput"
;
LOG
(
INFO
)
<<
"GetOutput"
;
GetOutput
(
slimmed_x_v
,
cls_inds_v
);
GetOutput
(
slimmed_x_v
,
cls_inds_v
);
ASSERT_EQ
(
cls_inds_v
[
0
],
2
);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
ASSERT_EQ
(
cls_inds_v
[
1
],
3
);
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ
(
cls_inds_v
[
2
],
2
);
ASSERT_EQ
(
cls_inds_v
[
3
],
3
);
ASSERT_EQ
(
slimmed_x_v
[
0
],
2
);
ASSERT_EQ
(
cls_inds_v
[
4
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
1
],
1
);
ASSERT_EQ
(
cls_inds_v
[
5
],
3
);
ASSERT_EQ
(
slimmed_x_v
[
2
],
2
);
ASSERT_EQ
(
cls_inds_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
3
],
1
);
ASSERT_EQ
(
cls_inds_v
[
7
],
3
);
ASSERT_EQ
(
slimmed_x_v
[
4
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
5
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
7
],
1
);
LOG
(
INFO
)
<<
"finish"
;
LOG
(
INFO
)
<<
"finish"
;
#endif
#endif
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录