Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
65c17315
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
65c17315
编写于
12月 21, 2022
作者:
W
Wangzheee
提交者:
GitHub
12月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference]optimize token prune for no varlen (#49094)
* optimize token prune for no varlen
上级
4cdeab7b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
546 addition
and
450 deletion
+546
-450
paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc
.../fluid/inference/tensorrt/convert/fused_token_prune_op.cc
+18
-15
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
.../inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
+281
-410
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
...d/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
+24
-3
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
+223
-22
未找到文件。
paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc
浏览文件 @
65c17315
...
...
@@ -38,6 +38,17 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto
output_name
=
op_desc
.
Output
(
"SlimmedX"
)[
0
];
auto
out_inds_name
=
op_desc
.
Output
(
"CLSInds"
)[
0
];
if
(
engine_
->
with_dynamic_shape
())
{
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t
reduce_dim
=
0
;
reduce_dim
|=
1
<<
1
;
// 00000000000000000000000000000010
reduce_dim
|=
1
<<
2
;
// 00000000000000000000000000000110
bool
keep_dim
=
false
;
nvinfer1
::
ReduceOperation
reduce_type
=
nvinfer1
::
ReduceOperation
::
kSUM
;
auto
*
reduce_sum_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
Attn
,
reduce_type
,
reduce_dim
,
keep_dim
);
auto
*
Reduced
=
reduce_sum_layer
->
getOutput
(
0
);
bool
with_fp16
=
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
...
...
@@ -53,21 +64,10 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto
*
pos_id
=
engine_
->
GetITensor
(
"pos_id"
);
auto
*
mask_id
=
engine_
->
GetITensor
(
"mask_id"
);
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t
reduce_dim
=
0
;
reduce_dim
|=
1
<<
1
;
// 00000000000000000000000000000010
reduce_dim
|=
1
<<
2
;
// 00000000000000000000000000000110
bool
keep_dim
=
false
;
nvinfer1
::
ReduceOperation
reduce_type
=
nvinfer1
::
ReduceOperation
::
kSUM
;
auto
*
reduce_sum_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
Attn
,
reduce_type
,
reduce_dim
,
keep_dim
);
// reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType());
auto
*
Reduced
=
reduce_sum_layer
->
getOutput
(
0
);
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Reduced
,
X
,
Mask
,
NewMask
,
word_id
,
pos_id
,
mask_id
};
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
7
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
itensors
.
size
(),
plugin
);
// inputs'number: 7
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
...
...
@@ -87,10 +87,13 @@ class FusedTokenPruneOpConverter : public OpConverter {
layer
->
getOutput
(
4
)
->
setName
(
"mask_id_after_token_prune"
);
engine_
->
SetITensor
(
"mask_id"
,
layer
->
getOutput
(
4
));
}
else
{
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Attn
,
X
,
Mask
,
NewMask
};
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
4
,
plugin
);
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Reduced
,
X
,
Mask
,
NewMask
};
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
itensors
.
size
(),
plugin
);
// inputs'number: 4
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
layer
->
getOutput
(
1
)
->
setName
(
out_inds_name
.
c_str
());
engine_
->
SetITensor
(
out_inds_name
,
layer
->
getOutput
(
1
));
}
...
...
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
浏览文件 @
65c17315
...
...
@@ -31,150 +31,6 @@ namespace inference {
namespace
tensorrt
{
namespace
plugin
{
template
<
typename
T
>
__global__
void
ElementwiseMask
(
const
T
*
a
,
const
T
*
b
,
T
*
res
,
int
num_elements
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_elements
)
return
;
const
T
zero
=
0
;
res
[
tid
]
=
b
[
tid
]
>=
zero
?
a
[
tid
]
:
zero
;
#endif
}
template
<
typename
T
>
__global__
void
FillZero
(
T
*
data
,
int
len
)
{
auto
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
len
)
return
;
const
T
zero
=
0
;
data
[
tid
]
=
zero
;
}
__global__
void
FillIndex
(
int32_t
*
indices
,
int
num_raws
,
int
num_cols
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_raws
*
num_cols
)
return
;
int
col
=
tid
%
num_cols
;
int
raw
=
tid
/
num_cols
;
indices
[
tid
]
=
col
;
}
template
<
typename
T
>
__global__
void
MaximumFirst
(
T
*
mat
,
int
num_raws
,
int
num_cols
,
T
max_value
)
{
auto
raw
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
raw
>=
num_raws
)
return
;
mat
[
raw
*
num_cols
]
=
max_value
;
}
__global__
void
FillOffsets
(
int
*
offsets
,
int
num_raws
,
int
num_cols
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>
num_raws
)
return
;
offsets
[
tid
]
=
tid
*
num_cols
;
}
template
<
typename
T
>
__global__
void
Slice
(
const
T
*
src
,
T
*
dst
,
int
num_raws
,
int
src_num_cols
,
int
dst_num_cols
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_raws
*
dst_num_cols
)
return
;
int
raw
=
tid
/
dst_num_cols
;
int
col
=
tid
%
dst_num_cols
;
dst
[
tid
]
=
src
[
raw
*
src_num_cols
+
col
];
}
template
<
typename
T
>
__global__
void
ReduceSum2
(
const
T
*
src
,
T
*
dst
,
int
bsz
,
int
nb_head
,
int
max_seq_len
)
{
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
num_blocks_per_head
=
((
max_seq_len
/
blockDim
.
x
)
*
max_seq_len
);
int
batch
=
bid
/
(
nb_head
*
num_blocks_per_head
);
int
col
=
bid
%
max_seq_len
;
int
head
=
(
bid
/
num_blocks_per_head
)
%
nb_head
;
extern
__shared__
T
res_float
[];
res_float
[
tid
]
=
src
[
batch
*
(
nb_head
*
max_seq_len
*
max_seq_len
)
+
head
*
(
max_seq_len
*
max_seq_len
)
+
col
+
tid
*
max_seq_len
];
__syncthreads
();
for
(
int
offset
=
blockDim
.
x
>>
1
;
offset
>
0
;
offset
>>=
1
)
{
if
(
tid
<
offset
)
{
res_float
[
tid
]
+=
res_float
[
tid
+
offset
];
}
__syncthreads
();
if
(
offset
%
2
==
1
&&
tid
==
offset
-
2
)
{
res_float
[
tid
]
+=
res_float
[
tid
+
1
];
}
}
if
(
tid
==
0
)
{
auto
*
dst_addr
=
dst
+
batch
*
max_seq_len
+
col
;
atomicAdd
(
dst_addr
,
res_float
[
0
]);
}
}
template
<
>
__global__
void
ReduceSum2
<
half
>
(
const
half
*
src
,
half
*
dst
,
int
bsz
,
int
nb_head
,
int
max_seq_len
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
num_blocks_per_head
=
((
max_seq_len
/
blockDim
.
x
)
*
max_seq_len
);
int
batch
=
bid
/
(
nb_head
*
num_blocks_per_head
);
int
col
=
bid
%
max_seq_len
;
int
head
=
(
bid
/
num_blocks_per_head
)
%
nb_head
;
extern
__shared__
half
res_half
[];
res_half
[
tid
]
=
src
[
batch
*
(
nb_head
*
max_seq_len
*
max_seq_len
)
+
head
*
(
max_seq_len
*
max_seq_len
)
+
col
+
tid
*
max_seq_len
];
__syncthreads
();
for
(
int
offset
=
blockDim
.
x
>>
1
;
offset
>
0
;
offset
>>=
1
)
{
if
(
tid
<
offset
)
{
res_half
[
tid
]
+=
res_half
[
tid
+
offset
];
}
__syncthreads
();
if
(
offset
%
2
==
1
&&
tid
==
offset
-
2
)
{
res_half
[
tid
]
+=
res_half
[
tid
+
1
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
phi
::
fastAtomicAdd
<
platform
::
float16
>
(
reinterpret_cast
<
platform
::
float16
*>
(
dst
),
static_cast
<
size_t
>
(
batch
*
max_seq_len
+
col
),
static_cast
<
size_t
>
(
bsz
*
max_seq_len
),
static_cast
<
platform
::
float16
>
(
res_half
[
0
]));
}
#endif
}
template
<
typename
T
>
__global__
void
TakeAlongAxis
(
const
T
*
src
,
T
*
dst
,
int32_t
*
indices
,
int
num_raws
,
int
src_num_cols
,
int
dst_num_cols
,
int
num_elements
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_raws
*
dst_num_cols
)
return
;
int
raw
=
tid
/
dst_num_cols
;
int
col
=
tid
%
dst_num_cols
;
for
(
int
i
=
0
;
i
<
num_elements
;
++
i
)
{
dst
[
tid
*
num_elements
+
i
]
=
*
(
src
+
(
raw
*
src_num_cols
+
indices
[
tid
])
*
num_elements
+
i
);
}
}
__global__
void
compute_token_length
(
const
int32_t
*
src
,
int32_t
*
dst
,
float
scale
)
{
...
...
@@ -182,16 +38,18 @@ __global__ void compute_token_length(const int32_t* src,
dst
[
it
]
=
max
(
static_cast
<
int
>
((
src
[
it
+
1
]
-
src
[
it
])
*
scale
),
1
);
}
template
<
typename
T
>
__global__
void
fill_index_padding_score
(
int32_t
*
token_index
,
const
half
*
scores
,
int32_t
scores_size
,
half
*
padding_scores
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
token_index
[
tid
]
=
threadIdx
.
x
;
if
(
tid
<
scores_size
)
{
padding_scores
[
tid
]
=
scores
[
tid
];
const
T
*
scores
,
int32_t
sequnce_length
,
T
*
padding_scores
)
{
int
padding_scores_it
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
scores_it
=
threadIdx
.
x
+
blockIdx
.
x
*
sequnce_length
;
token_index
[
padding_scores_it
]
=
threadIdx
.
x
;
if
(
threadIdx
.
x
<
sequnce_length
)
{
padding_scores
[
padding_scores_it
]
=
scores
[
scores_it
];
}
else
{
padding_scores
[
tid
]
=
0
;
padding_scores
[
padding_scores_it
]
=
0
;
}
}
...
...
@@ -238,21 +96,64 @@ __global__ void general_topk_pair_sort(T* in_keys, int32_t* in_out_values) {
.
Store
(
in_out_values
+
block_offset
,
thread_values
);
}
__global__
void
varlen_prune_token
(
const
half
*
tokens
,
const
int32_t
*
token_pos
,
const
int32_t
*
token_index
,
half
*
output
)
{
__global__
void
varlen_prune_token_change_order
(
const
half
*
tokens
,
const
int32_t
*
token_pos
,
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
half
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
token_it
=
batch
*
gridDim
.
y
+
blockIdx
.
y
;
int
pre_value_it
=
token_it
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
;
int
token_index_it
=
batch
*
padding_token_length
+
blockIdx
.
y
;
if
(
token_index
[
token_it
]
<
token_pos
[
batch
+
1
]
-
token_pos
[
batch
])
{
output
[(
token_index
[
token_it
]
+
token_pos
[
batch
])
*
gridDim
.
z
*
blockDim
.
x
+
if
(
token_index
[
token_index_it
]
<
token_pos
[
batch
+
1
]
-
token_pos
[
batch
])
{
output
[(
token_index
[
token_index_it
]
+
token_pos
[
batch
])
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[
pre_value_it
];
}
}
template
<
typename
T
>
__global__
void
prune_token_change_order
(
const
T
*
tokens
,
int32_t
new_sequnce_length
,
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
T
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
token_it
=
batch
*
gridDim
.
y
+
blockIdx
.
y
;
int
pre_value_it
=
token_it
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
;
int
token_index_it
=
batch
*
padding_token_length
+
blockIdx
.
y
;
if
(
token_index
[
token_index_it
]
<
new_sequnce_length
)
{
output
[(
batch
*
new_sequnce_length
+
token_index
[
token_index_it
])
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[
pre_value_it
];
}
}
template
<
typename
T
>
__global__
void
prune_token_keep_order
(
const
T
*
tokens
,
int32_t
pre_sequnce_length
,
int32_t
new_sequnce_length
,
const
int32_t
padding_token_length
,
const
int32_t
*
token_index
,
T
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
index
=
0
;
for
(
int
i
=
0
;
i
<
pre_sequnce_length
;
++
i
)
{
if
(
token_index
[
batch
*
padding_token_length
+
i
]
<
new_sequnce_length
)
{
output
[(
batch
*
new_sequnce_length
+
index
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[(
batch
*
pre_sequnce_length
+
i
)
*
gridDim
.
y
*
blockDim
.
x
+
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
];
index
++
;
}
}
}
nvinfer1
::
DimsExprs
FusedTokenPrunePluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
...
...
@@ -353,7 +254,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
"should be half for varseqlen."
));
}
}
else
if
(
pos
==
6
||
pos
==
11
)
{
// mask_id, mask_id_out
return
(
in
.
type
==
nvinfer1
::
DataType
::
k
FLOAT
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
k
HALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
...
...
@@ -364,7 +265,6 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
if
(
with_fp16_
)
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
...
...
@@ -373,8 +273,7 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
0
];
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
else
{
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
return
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
}
}
...
...
@@ -425,199 +324,6 @@ size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
return
size
;
}
template
<
typename
T
>
inline
void
enqueueImpl
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace_ptr
,
cudaStream_t
stream
,
int
device_id
,
T
max_value
,
bool
keep_first_token_
,
bool
keep_order_
)
{
// Dims
auto
attn_dims
=
input_desc
[
0
].
dims
;
auto
x_dims
=
input_desc
[
1
].
dims
;
auto
new_mask_dims
=
input_desc
[
3
].
dims
;
auto
bsz
=
attn_dims
.
d
[
0
],
nb_head
=
attn_dims
.
d
[
1
],
max_seq_len
=
attn_dims
.
d
[
2
];
auto
c
=
x_dims
.
d
[
2
];
auto
slimmed_x_len
=
new_mask_dims
.
d
[
2
];
// Inputs
const
T
*
attn_data
=
static_cast
<
const
T
*>
(
inputs
[
0
]);
const
T
*
x_data
=
static_cast
<
const
T
*>
(
inputs
[
1
]);
const
T
*
mask_data
=
static_cast
<
const
T
*>
(
inputs
[
2
]);
// Outputs
T
*
output_data
=
static_cast
<
T
*>
(
outputs
[
0
]);
int32_t
*
output_indices_data
=
static_cast
<
int32_t
*>
(
outputs
[
1
]);
int
total
=
bsz
*
nb_head
*
max_seq_len
*
max_seq_len
;
int
block
=
operators
::
ComputeBlockSize
(
max_seq_len
);
int
grid
=
operators
::
CeilDivide
(
total
,
block
);
// Workspace for intermediate variable
char
*
workspace
=
static_cast
<
char
*>
(
workspace_ptr
);
T
*
attn_tmp_data
=
reinterpret_cast
<
T
*>
(
workspace
);
size_t
offset
=
total
*
sizeof
(
T
);
T
*
attn_accu_data
=
reinterpret_cast
<
T
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
T
);
int32_t
*
attn_accu_indices_data
=
reinterpret_cast
<
int32_t
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
int32_t
);
T
*
sort_attn_accu_data
=
reinterpret_cast
<
T
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
T
);
int32_t
*
sort_attn_accu_indices_data
=
reinterpret_cast
<
int32_t
*>
(
workspace
+
offset
);
offset
+=
bsz
*
max_seq_len
*
sizeof
(
int32_t
);
int
*
offsets_data
=
reinterpret_cast
<
int
*>
(
workspace
+
offset
);
offset
+=
(
bsz
+
1
)
*
sizeof
(
int
);
int32_t
*
slimmed_sort_attn_accu_indices_data
=
reinterpret_cast
<
int32_t
*>
(
workspace
+
offset
);
// 1. Filter attn by mask
ElementwiseMask
<
T
>
<<<
grid
,
block
,
0
,
stream
>>>
(
attn_data
,
mask_data
,
attn_tmp_data
,
total
);
total
=
bsz
*
max_seq_len
;
block
=
operators
::
ComputeBlockSize
(
max_seq_len
);
grid
=
operators
::
CeilDivide
(
total
,
block
);
FillZero
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
attn_accu_data
,
total
);
// 2. Reduce sum
total
=
bsz
*
nb_head
*
max_seq_len
*
max_seq_len
;
int
block_tmp
=
max_seq_len
;
while
(
block_tmp
>
1024
)
block_tmp
/=
2
;
// if max seq len > 1024, it must be 2^n
block
=
block_tmp
;
// make sure max_seq_len is an integral multiple of block_size
grid
=
operators
::
CeilDivide
(
total
,
block
);
ReduceSum2
<
T
><<<
grid
,
block
,
block
*
sizeof
(
T
),
stream
>>>
(
attn_tmp_data
,
attn_accu_data
,
bsz
,
nb_head
,
max_seq_len
);
// 3. Prepare token indices
total
=
bsz
*
max_seq_len
;
block
=
operators
::
ComputeBlockSize
(
max_seq_len
);
grid
=
operators
::
CeilDivide
(
total
,
block
);
FillIndex
<<<
grid
,
block
,
0
,
stream
>>>
(
attn_accu_indices_data
,
bsz
,
max_seq_len
);
// 4. Sort token indices by attn
if
(
keep_first_token_
)
{
MaximumFirst
<
T
>
<<<
bsz
,
1
,
0
,
stream
>>>
(
attn_accu_data
,
bsz
,
max_seq_len
,
max_value
);
}
size_t
temp_storage_bytes
=
-
1
;
int
num_items
=
bsz
*
max_seq_len
;
int
num_segments
=
bsz
;
FillOffsets
<<<
bsz
+
1
,
1
,
0
,
stream
>>>
(
offsets_data
,
bsz
,
max_seq_len
);
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortPairsDescending
(
nullptr
,
temp_storage_bytes
,
attn_accu_data
,
sort_attn_accu_data
,
attn_accu_indices_data
,
sort_attn_accu_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
T
)
*
8
,
stream
));
int64_t
temp_size
=
temp_storage_bytes
;
phi
::
DenseTensor
temp_storage
;
auto
*
temp_storage_data
=
temp_storage
.
mutable_data
<
uint8_t
>
(
{
temp_size
},
platform
::
CUDAPlace
(
device_id
));
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortPairsDescending
(
temp_storage_data
,
temp_storage_bytes
,
attn_accu_data
,
sort_attn_accu_data
,
attn_accu_indices_data
,
sort_attn_accu_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
T
)
*
8
,
stream
));
// 5. Slice
total
=
bsz
*
slimmed_x_len
;
block
=
operators
::
ComputeBlockSize
(
slimmed_x_len
);
grid
=
operators
::
CeilDivide
(
total
,
block
);
Slice
<
int32_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
sort_attn_accu_indices_data
,
slimmed_sort_attn_accu_indices_data
,
bsz
,
max_seq_len
,
slimmed_x_len
);
if
(
keep_order_
)
{
// 6. reorder
num_items
=
bsz
*
slimmed_x_len
;
FillOffsets
<<<
bsz
+
1
,
1
,
0
,
stream
>>>
(
offsets_data
,
bsz
,
slimmed_x_len
);
temp_storage_bytes
=
-
1
;
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortKeys
(
nullptr
,
temp_storage_bytes
,
slimmed_sort_attn_accu_indices_data
,
output_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
int32_t
)
*
8
,
stream
));
temp_size
=
temp_storage_bytes
;
temp_storage
.
Resize
({
temp_size
});
temp_storage_data
=
temp_storage
.
mutable_data
<
uint8_t
>
(
platform
::
CUDAPlace
(
device_id
));
PADDLE_ENFORCE_GPU_SUCCESS
(
cub
::
DeviceSegmentedRadixSort
::
SortKeys
(
temp_storage_data
,
temp_storage_bytes
,
slimmed_sort_attn_accu_indices_data
,
output_indices_data
,
num_items
,
num_segments
,
offsets_data
,
offsets_data
+
1
,
0
,
sizeof
(
int32_t
)
*
8
,
stream
));
TakeAlongAxis
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
x_data
,
output_data
,
output_indices_data
,
bsz
,
max_seq_len
,
slimmed_x_len
,
c
);
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMemcpy
(
output_indices_data
,
slimmed_sort_attn_accu_indices_data
,
bsz
*
slimmed_x_len
*
sizeof
(
int32_t
),
cudaMemcpyDeviceToDevice
));
TakeAlongAxis
<
T
>
<<<
grid
,
block
,
0
,
stream
>>>
(
x_data
,
output_data
,
slimmed_sort_attn_accu_indices_data
,
bsz
,
max_seq_len
,
slimmed_x_len
,
c
);
}
}
int
FusedTokenPrunePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
...
...
@@ -628,49 +334,56 @@ int FusedTokenPrunePluginDynamic::enqueue(
if
(
flag_varseqlen_
)
{
if
(
!
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
&&
input_desc
[
1
].
type
==
nvinfer1
::
DataType
::
kHALF
))
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'type must half
"
));
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'type must half for varseqlen
"
));
}
float
scale
=
static_cast
<
float
>
(
input_desc
[
3
].
dims
.
d
[
2
])
/
input_desc
[
6
].
dims
.
d
[
1
];
const
int32_t
*
inputs5
=
static_cast
<
const
int32_t
*>
(
inputs
[
5
]);
// pre pos id
int32_t
*
outputs3
=
static_cast
<
int32_t
*>
(
outputs
[
3
]);
// new pos id
half
*
outputs0
=
static_cast
<
half
*>
(
outputs
[
0
]);
static_cast
<
float
>
(
input_desc
[
3
].
dims
.
d
[
2
])
/
input_desc
[
2
].
dims
.
d
[
2
];
const
int32_t
*
input5
=
static_cast
<
const
int32_t
*>
(
inputs
[
5
]);
// pre pos id
int32_t
*
output3
=
static_cast
<
int32_t
*>
(
outputs
[
3
]);
// new pos id
half
*
output0
=
static_cast
<
half
*>
(
outputs
[
0
]);
const
int32_t
B
=
input_desc
[
1
].
dims
.
d
[
0
];
// batchs
const
int32_t
max_sequnce_length
=
input_desc
[
1
].
dims
.
d
[
1
];
// max sequnce length
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
//
vector length
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
//
hidden size
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
const
int32_t
scores_size
=
B
*
max_sequnce_length
;
int32_t
padding_token_length
;
if
(
max_sequnce_length
<=
128
)
{
if
(
max_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
max_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
max_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
max_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
max_sequnce_length
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <=
384
"
));
"Token_prune'token_length must <=
512
"
));
}
// 1. Compute the token length after pruning.
compute_token_length
<<<
1
,
B
,
0
,
stream
>>>
(
input
s
5
,
pruned_token_lengths_
,
scale
);
input5
,
pruned_token_lengths_
,
scale
);
fill_index_padding_score
<<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
scores_size
,
padding_scores_
);
// 2. Padding scores
fill_index_padding_score
<
half
><<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
max_sequnce_length
,
static_cast
<
half
*>
(
padding_scores_
));
// 3. compute new pos id
// Determine temporary device storage requirements
void
*
d_temp_storage
=
NULL
;
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
pruned_token_lengths_
,
output
s
3
,
output3
,
B
+
1
);
// Allocate temporary storage
cudaMalloc
(
&
d_temp_storage
,
temp_storage_bytes
);
...
...
@@ -679,20 +392,28 @@ int FusedTokenPrunePluginDynamic::enqueue(
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
pruned_token_lengths_
,
output
s
3
,
output3
,
B
+
1
);
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
half
,
32
,
4
>
<<<
B
,
32
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 128
// 4. sort scores
if
(
padding_token_length
==
64
)
{
general_topk_pair_sort
<
half
,
32
,
2
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 64
}
else
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
half
,
32
,
4
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
half
,
64
,
4
>
<<<
B
,
64
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 256
general_topk_pair_sort
<
half
,
64
,
4
><<<
B
,
64
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 256
}
else
if
(
padding_token_length
==
384
)
{
general_topk_pair_sort
<
half
,
96
,
4
><<<
B
,
96
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 384
}
else
{
general_topk_pair_sort
<
half
,
96
,
4
>
<<<
B
,
96
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 384
general_topk_pair_sort
<
half
,
128
,
4
><<<
B
,
128
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 512
}
// 5. compute output
int32_t
num_threads
;
if
(
length
<
1024
)
{
num_threads
=
length
;
...
...
@@ -723,46 +444,196 @@ int FusedTokenPrunePluginDynamic::enqueue(
B
,
max_sequnce_length
,
length
/
num_threads
);
// batchs, max_sequnce_length, vector_ength/***
varlen_prune_token
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
output
s3
,
token_index_
,
outputs
0
);
varlen_prune_token
_change_order
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
output
3
,
padding_token_length
,
token_index_
,
output
0
);
}
else
{
auto
input_type
=
input_desc
[
0
].
type
;
auto
attn_dims
=
input_desc
[
0
].
dims
;
auto
bsz
=
attn_dims
.
d
[
0
],
nb_head
=
attn_dims
.
d
[
1
],
max_seq_len
=
attn_dims
.
d
[
2
];
int
device_id
;
cudaGetDevice
(
&
device_id
);
const
int32_t
B
=
input_desc
[
1
].
dims
.
d
[
0
];
// batchs
const
int32_t
pre_sequnce_length
=
input_desc
[
1
].
dims
.
d
[
1
];
const
int32_t
new_sequnce_length
=
input_desc
[
3
].
dims
.
d
[
2
];
// new mask
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
// hidden size
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp32"
;
const
float
*
scores
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
// reduce sum
const
float
*
tokens
=
static_cast
<
const
float
*>
(
inputs
[
1
]);
// X
float
*
output0
=
static_cast
<
float
*>
(
outputs
[
0
]);
int32_t
padding_token_length
;
if
(
pre_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
pre_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
pre_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
pre_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
pre_sequnce_length
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <= 512"
));
}
float
max
=
std
::
numeric_limits
<
float
>::
max
();
enqueueImpl
<
float
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
,
device_id
,
max
,
keep_first_token_
,
keep_order_
);
// 1. Padding scores
fill_index_padding_score
<
float
><<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
pre_sequnce_length
,
static_cast
<
float
*>
(
padding_scores_
));
// 2. sort scores
if
(
padding_token_length
==
64
)
{
general_topk_pair_sort
<
float
,
32
,
2
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 64
}
else
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
float
,
32
,
4
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
float
,
64
,
4
><<<
B
,
64
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 256
}
else
if
(
padding_token_length
==
384
)
{
general_topk_pair_sort
<
float
,
96
,
4
><<<
B
,
96
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 384
}
else
{
general_topk_pair_sort
<
float
,
128
,
4
><<<
B
,
128
,
0
,
stream
>>>
(
static_cast
<
float
*>
(
padding_scores_
),
token_index_
);
// 512
}
// 3. compute output
int32_t
num_threads
;
if
(
length
<
1024
)
{
num_threads
=
length
;
}
else
{
if
(
length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
if
(
keep_order_
)
{
const
dim3
num_blocks
(
B
,
length
/
num_threads
);
prune_token_keep_order
<
float
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
pre_sequnce_length
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
else
{
const
dim3
num_blocks
(
B
,
pre_sequnce_length
,
length
/
num_threads
);
prune_token_change_order
<
float
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp16"
;
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
// X
half
*
output0
=
static_cast
<
half
*>
(
outputs
[
0
]);
int32_t
padding_token_length
;
if
(
pre_sequnce_length
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
pre_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
pre_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
pre_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
pre_sequnce_length
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <= 512"
));
}
// 1. Padding scores
fill_index_padding_score
<
half
><<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
pre_sequnce_length
,
static_cast
<
half
*>
(
padding_scores_
));
// 2. sort scores
if
(
padding_token_length
==
64
)
{
general_topk_pair_sort
<
half
,
32
,
2
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 64
}
else
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
half
,
32
,
4
><<<
B
,
32
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
half
,
64
,
4
><<<
B
,
64
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 256
}
else
if
(
padding_token_length
==
384
)
{
general_topk_pair_sort
<
half
,
96
,
4
><<<
B
,
96
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 384
}
else
{
general_topk_pair_sort
<
half
,
128
,
4
><<<
B
,
128
,
0
,
stream
>>>
(
static_cast
<
half
*>
(
padding_scores_
),
token_index_
);
// 512
}
half
max
=
65504.0
;
enqueueImpl
<
half
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
,
device_id
,
max
,
keep_first_token_
,
keep_order_
);
// 3. compute output
int32_t
num_threads
;
if
(
length
<
1024
)
{
num_threads
=
length
;
}
else
{
if
(
length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
if
(
keep_order_
)
{
const
dim3
num_blocks
(
B
,
length
/
num_threads
);
prune_token_keep_order
<
half
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
pre_sequnce_length
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
else
{
const
dim3
num_blocks
(
B
,
pre_sequnce_length
,
length
/
num_threads
);
prune_token_change_order
<
half
>
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
new_sequnce_length
,
padding_token_length
,
token_index_
,
output0
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The FusedTokenPrune TRT Plugin's input type "
...
...
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
浏览文件 @
65c17315
...
...
@@ -93,12 +93,33 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
int
nb_outputs
)
TRT_NOEXCEPT
override
{
max_batchs_
=
in
[
1
].
max
.
d
[
0
];
max_token_length_
=
in
[
1
].
max
.
d
[
1
];
int32_t
padding_token_length
;
if
(
max_token_length_
<=
64
)
{
padding_token_length
=
64
;
}
else
if
(
max_token_length_
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
max_token_length_
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
max_token_length_
<=
384
)
{
padding_token_length
=
384
;
}
else
if
(
max_token_length_
<=
512
)
{
padding_token_length
=
512
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length(max) must <= 512"
));
}
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
pruned_token_lengths_
,
(
max_batchs_
+
1
)
*
sizeof
(
int32_t
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
token_index_
,
max_batchs_
*
max_token_length_
*
sizeof
(
int32_t
)));
&
token_index_
,
max_batchs_
*
padding_token_length
*
sizeof
(
int32_t
)));
int32_t
type_size
=
4
;
if
(
in
[
0
].
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
)
{
type_size
=
2
;
}
else
{
type_size
=
4
;
}
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
padding_scores_
,
max_batchs_
*
max_token_length_
*
sizeof
(
half
)
));
&
padding_scores_
,
max_batchs_
*
padding_token_length
*
type_size
));
}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
...
...
@@ -129,7 +150,7 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
int32_t
*
token_index_
;
int32_t
max_batchs_
;
int32_t
max_token_length_
;
half
*
padding_scores_
;
void
*
padding_scores_
;
};
class
FusedTokenPrunePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
...
...
paddle/fluid/inference/tensorrt/test_dynamic_engine.cc
浏览文件 @
65c17315
...
...
@@ -352,24 +352,24 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
ctx_
->
PartialInitWithAllocator
();
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape
=
{
{
"attn"
,
{
4
,
1
,
4
,
4
}},
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"attn"
,
{
4
,
1
,
4
,
4
}},
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
=
{
{
"attn"
,
{
4
,
1
,
4
,
4
}},
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
engine_
=
new
TensorRTEngine
(
16
,
1
<<
10
,
AnalysisConfig
::
Precision
::
k
Half
,
AnalysisConfig
::
Precision
::
k
Float32
,
nullptr
,
0
,
min_input_shape
,
...
...
@@ -391,7 +391,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
}
}
void
PrepareInputOutput
(
const
std
::
vector
<
std
::
vector
<
float
16
>>
inputs
,
void
PrepareInputOutput
(
const
std
::
vector
<
std
::
vector
<
float
>>
inputs
,
std
::
vector
<
std
::
vector
<
int
>>
output_shapes
)
{
LOG
(
INFO
)
<<
"PrepareInputOutput"
;
int
num_inputs
=
inputs
.
size
();
...
...
@@ -423,15 +423,15 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
tensorrt
::
plugin
::
TrtPluginRegistry
::
Global
()
->
RegistToTrt
();
auto
*
attn
=
engine_
->
DeclareInput
(
"attn"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
"attn"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims2
{
-
1
,
4
});
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims3
{
-
1
,
4
,
1
});
"x"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims3
{
-
1
,
4
,
1
});
auto
*
mask
=
engine_
->
DeclareInput
(
"mask"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
"mask"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
auto
*
new_mask
=
engine_
->
DeclareInput
(
"new_mask"
,
nvinfer1
::
DataType
::
k
HALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
2
,
2
});
"new_mask"
,
nvinfer1
::
DataType
::
k
FLOAT
,
nvinfer1
::
Dims4
{
-
1
,
1
,
2
,
2
});
plugin
::
FusedTokenPrunePluginDynamic
*
plugin
=
new
plugin
::
FusedTokenPrunePluginDynamic
(
tru
e
,
new
plugin
::
FusedTokenPrunePluginDynamic
(
/*with_fp16*/
fals
e
,
/*keep_first_token*/
false
,
/*keep_order*/
true
,
/*flag_varseqlen*/
false
);
...
...
@@ -449,18 +449,215 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
6
);
LOG
(
INFO
)
<<
"create input"
;
std
::
vector
<
float16
>
attn_v
(
64
);
std
::
vector
<
float
>
attn_v
(
16
);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
attn_v
[
j
*
4
+
k
]
=
k
;
}
}
std
::
vector
<
float
>
x_v
(
16
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
x_v
[
i
*
4
+
j
]
=
4
-
j
;
}
}
std
::
vector
<
float
>
mask_v
(
64
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
attn_v
[
i
*
16
+
j
*
4
+
k
]
=
k
;
mask_v
[
i
*
16
+
j
*
4
+
k
]
=
1
;
}
}
}
std
::
vector
<
float
>
new_mask_v
(
16
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
2
;
++
j
)
{
for
(
int
k
=
0
;
k
<
2
;
++
k
)
{
new_mask_v
[
i
*
4
+
j
*
2
+
k
]
=
1
;
}
}
}
LOG
(
INFO
)
<<
"create output"
;
std
::
vector
<
int
>
out_slimmed_x_shape
{
4
,
2
,
1
};
std
::
vector
<
int
>
out_cls_ins_shape
{
4
,
2
};
PrepareInputOutput
({
attn_v
,
x_v
,
mask_v
,
new_mask_v
},
{
out_slimmed_x_shape
,
out_cls_ins_shape
});
auto
*
attn_gpu_data
=
inputs_
[
0
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
x_gpu_data
=
inputs_
[
1
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
mask_gpu_data
=
inputs_
[
2
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
new_mask_gpu_data
=
inputs_
[
3
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
slimmed_x_gpu_data
=
outputs_
[
0
].
mutable_data
<
float
>
(
ctx_
->
GetPlace
());
auto
*
cls_inds_gpu_data
=
outputs_
[
1
].
mutable_data
<
int32_t
>
(
ctx_
->
GetPlace
());
LOG
(
INFO
)
<<
"create buffers"
;
std
::
vector
<
void
*>
buffers
(
6
);
buffers
[
0
]
=
reinterpret_cast
<
void
*>
(
attn_gpu_data
);
buffers
[
1
]
=
reinterpret_cast
<
void
*>
(
x_gpu_data
);
buffers
[
2
]
=
reinterpret_cast
<
void
*>
(
mask_gpu_data
);
buffers
[
3
]
=
reinterpret_cast
<
void
*>
(
new_mask_gpu_data
);
buffers
[
4
]
=
reinterpret_cast
<
void
*>
(
slimmed_x_gpu_data
);
buffers
[
5
]
=
reinterpret_cast
<
void
*>
(
cls_inds_gpu_data
);
LOG
(
INFO
)
<<
"Execute"
;
engine_
->
Execute
(
4
,
&
buffers
,
ctx_
->
stream
());
std
::
vector
<
float
>
slimmed_x_v
(
8
);
std
::
vector
<
int32_t
>
cls_inds_v
;
LOG
(
INFO
)
<<
"GetOutput"
;
GetOutput
(
slimmed_x_v
,
cls_inds_v
);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ
(
slimmed_x_v
[
0
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
1
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
2
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
3
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
4
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
5
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
7
],
1
);
LOG
(
INFO
)
<<
"finish"
;
#endif
}
class
TensorRTDynamicTestFusedTokenPruneHalf
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
ctx_
=
new
phi
::
GPUContext
(
platform
::
CUDAPlace
(
0
));
ctx_
->
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
platform
::
CUDAPlace
(
0
),
ctx_
->
stream
())
.
get
());
ctx_
->
SetHostAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
ctx_
->
SetZeroAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetZeroAllocator
(
platform
::
CUDAPlace
(
0
))
.
get
());
ctx_
->
SetPinnedAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CUDAPinnedPlace
())
.
get
());
ctx_
->
PartialInitWithAllocator
();
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape
=
{
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
=
{
{
"attn"
,
{
4
,
4
}},
{
"x"
,
{
4
,
4
,
1
}},
{
"mask"
,
{
4
,
1
,
4
,
4
}},
{
"new_mask"
,
{
4
,
1
,
2
,
2
}}};
engine_
=
new
TensorRTEngine
(
16
,
1
<<
10
,
AnalysisConfig
::
Precision
::
kHalf
,
nullptr
,
0
,
min_input_shape
,
max_input_shape
,
optim_input_shape
,
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
(),
false
,
phi
::
DataType
::
FLOAT16
,
NaiveLogger
::
Global
());
engine_
->
InitNetwork
();
}
void
TearDown
()
override
{
if
(
engine_
)
{
delete
engine_
;
engine_
=
nullptr
;
}
}
void
PrepareInputOutput
(
const
std
::
vector
<
std
::
vector
<
float16
>>
inputs
,
std
::
vector
<
std
::
vector
<
int
>>
output_shapes
)
{
LOG
(
INFO
)
<<
"PrepareInputOutput"
;
int
num_inputs
=
inputs
.
size
();
int
num_outputs
=
output_shapes
.
size
();
inputs_
.
resize
(
num_inputs
);
outputs_
.
resize
(
num_outputs
);
for
(
int
i
=
0
;
i
<
num_inputs
;
++
i
)
{
paddle
::
framework
::
TensorFromVector
(
inputs
[
i
],
*
ctx_
,
&
inputs_
[
i
]);
}
for
(
int
i
=
0
;
i
<
num_outputs
;
++
i
)
{
outputs_
[
i
].
Resize
(
phi
::
make_ddim
(
output_shapes
[
i
]));
}
}
void
GetOutput
(
std
::
vector
<
float
>
&
slimmed_x
,
// NOLINT
std
::
vector
<
int32_t
>
&
cls_inds
)
{
// NOLINT
paddle
::
framework
::
TensorToVector
(
outputs_
[
0
],
*
ctx_
,
&
slimmed_x
);
paddle
::
framework
::
TensorToVector
(
outputs_
[
1
],
*
ctx_
,
&
cls_inds
);
}
protected:
std
::
vector
<
phi
::
DenseTensor
>
inputs_
;
std
::
vector
<
phi
::
DenseTensor
>
outputs_
;
TensorRTEngine
*
engine_
;
phi
::
GPUContext
*
ctx_
;
};
TEST_F
(
TensorRTDynamicTestFusedTokenPruneHalf
,
test_fused_token_prune
)
{
#if IS_TRT_VERSION_GE(8000)
tensorrt
::
plugin
::
TrtPluginRegistry
::
Global
()
->
RegistToTrt
();
auto
*
attn
=
engine_
->
DeclareInput
(
"attn"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims2
{
-
1
,
4
});
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims3
{
-
1
,
4
,
1
});
auto
*
mask
=
engine_
->
DeclareInput
(
"mask"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
4
,
4
});
auto
*
new_mask
=
engine_
->
DeclareInput
(
"new_mask"
,
nvinfer1
::
DataType
::
kHALF
,
nvinfer1
::
Dims4
{
-
1
,
1
,
2
,
2
});
plugin
::
FusedTokenPrunePluginDynamic
*
plugin
=
new
plugin
::
FusedTokenPrunePluginDynamic
(
/*with_fp16*/
true
,
/*keep_first_token*/
false
,
/*keep_order*/
true
,
/*flag_varseqlen*/
false
);
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
attn
,
x
,
mask
,
new_mask
};
auto
*
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
4
,
plugin
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
platform
::
errors
::
InvalidArgument
(
"TRT fused_token_prune layer building failed."
));
std
::
vector
<
std
::
string
>
output_tensor_names
{
"out_slimmed_x"
,
"out_cls_inds"
};
for
(
size_t
i
=
0
;
i
<
2
;
i
++
)
{
layer
->
getOutput
(
i
)
->
setName
(
output_tensor_names
[
i
].
c_str
());
engine_
->
DeclareOutput
(
layer
,
i
,
output_tensor_names
[
i
]);
}
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
6
);
LOG
(
INFO
)
<<
"create input"
;
std
::
vector
<
float16
>
attn_v
(
16
);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
attn_v
[
j
*
4
+
k
]
=
k
;
}
}
std
::
vector
<
float16
>
x_v
(
16
);
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
x_v
[
i
*
4
+
j
]
=
1
;
x_v
[
i
*
4
+
j
]
=
4
-
j
;
}
}
std
::
vector
<
float16
>
mask_v
(
64
);
...
...
@@ -509,20 +706,24 @@ TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
engine_
->
Execute
(
4
,
&
buffers
,
ctx_
->
stream
());
std
::
vector
<
float
>
slimmed_x_v
;
std
::
vector
<
float
>
slimmed_x_v
(
8
)
;
std
::
vector
<
int32_t
>
cls_inds_v
;
LOG
(
INFO
)
<<
"GetOutput"
;
GetOutput
(
slimmed_x_v
,
cls_inds_v
);
ASSERT_EQ
(
cls_inds_v
[
0
],
2
);
ASSERT_EQ
(
cls_inds_v
[
1
],
3
);
ASSERT_EQ
(
cls_inds_v
[
2
],
2
);
ASSERT_EQ
(
cls_inds_v
[
3
],
3
);
ASSERT_EQ
(
cls_inds_v
[
4
],
2
);
ASSERT_EQ
(
cls_inds_v
[
5
],
3
);
ASSERT_EQ
(
cls_inds_v
[
6
],
2
);
ASSERT_EQ
(
cls_inds_v
[
7
],
3
);
// slimmed_x_v: [[4,3,2,1],[4,3,2,1],[4,3,2,1],[4,3,2,1]] ->
// [[2,1],[2,1],[2,1],[2,1]]
ASSERT_EQ
(
slimmed_x_v
[
0
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
1
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
2
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
3
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
4
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
5
],
1
);
ASSERT_EQ
(
slimmed_x_v
[
6
],
2
);
ASSERT_EQ
(
slimmed_x_v
[
7
],
1
);
LOG
(
INFO
)
<<
"finish"
;
#endif
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录