Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
29782728
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
29782728
编写于
11月 24, 2022
作者:
W
Wangzheee
提交者:
GitHub
11月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference]optimize token prune for Paddle-TensorRT (#48241)
* optimize token prune
上级
d39f3fb6
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
422 addition
and
185 deletion
+422
-185
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
...fluid/framework/ir/remove_padding_recover_padding_pass.cc
+57
-0
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
.../fluid/framework/ir/remove_padding_recover_padding_pass.h
+11
-1
paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc
.../fluid/inference/tensorrt/convert/fused_token_prune_op.cc
+14
-1
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
.../inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
+250
-121
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
...d/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
+27
-6
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
...fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
+32
-27
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
.../fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
+31
-26
paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc
...nference/tensorrt/plugin/test_fused_token_prune_plugin.cc
+0
-3
未找到文件。
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
浏览文件 @
29782728
...
@@ -131,6 +131,21 @@ void Activation::operator()() {
...
@@ -131,6 +131,21 @@ void Activation::operator()() {
// Add links for activation op.
// Add links for activation op.
activation_op
->
LinksFrom
({
activation_input
}).
LinksTo
({
activation_out
});
activation_op
->
LinksFrom
({
activation_input
}).
LinksTo
({
activation_out
});
}
}
void
FusedTokenPrune
::
operator
()()
{
// Create nodes for fused_token_prune.
auto
*
fused_token_prune_input
=
pattern
->
NewNode
(
fused_token_prune_input_repr
())
->
assert_is_op_input
(
"fused_token_prune"
,
"X"
);
auto
*
fused_token_prune_op
=
pattern
->
NewNode
(
fused_token_prune_op_repr
())
->
assert_is_op
(
"fused_token_prune"
);
auto
*
fused_token_prune_output
=
pattern
->
NewNode
(
fused_token_prune_output_repr
())
->
assert_is_op_output
(
"fused_token_prune"
,
"SlimmedX"
);
fused_token_prune_op
->
LinksFrom
({
fused_token_prune_input
})
.
LinksTo
({
fused_token_prune_output
});
}
}
// namespace patterns
}
// namespace patterns
void
RemovePaddingRecoverPaddingPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
RemovePaddingRecoverPaddingPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
...
@@ -563,6 +578,48 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -563,6 +578,48 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
};
};
gpd6
(
graph
,
handler6
);
gpd6
(
graph
,
handler6
);
GraphPatternDetector
gpd7
;
patterns
::
FusedTokenPrune
fused_token_prune
(
gpd7
.
mutable_pattern
(),
"remove_padding_recover_padding_pass"
);
fused_token_prune
();
auto
handler7
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
3
)
<<
"remove_padding_recover_padding_pass for transformer: "
"fused_token_prune"
;
GET_IR_NODE_FROM_SUBGRAPH
(
fused_token_prune_input
,
fused_token_prune_input
,
fused_token_prune
);
GET_IR_NODE_FROM_SUBGRAPH
(
fused_token_prune_op
,
fused_token_prune_op
,
fused_token_prune
);
GET_IR_NODE_FROM_SUBGRAPH
(
fused_token_prune_output
,
fused_token_prune_output
,
fused_token_prune
);
std
::
vector
<
int64_t
>
fused_token_prune_input_shape
=
fused_token_prune_input
->
Var
()
->
GetShape
();
check_flag
=
true
;
if
(
fused_token_prune_input_shape
.
size
()
!=
multihead_matmul_input_shape
.
size
())
{
check_flag
=
false
;
VLOG
(
3
)
<<
"Transformer model remove_padding shape check failed, return "
"remove_padding pass."
;
return
;
}
for
(
size_t
i
=
0
;
i
<
fused_token_prune_input_shape
.
size
();
++
i
)
{
if
(
fused_token_prune_input_shape
[
i
]
!=
multihead_matmul_input_shape
[
i
])
{
check_flag
=
false
;
}
}
if
(
!
check_flag
)
{
VLOG
(
3
)
<<
"Transformer model remove_padding shape check failed, return "
"remove_padding pass."
;
return
;
}
insert_recover_padding_op
(
fused_token_prune_op
,
fused_token_prune_output
);
found_subgraph_count
++
;
};
gpd7
(
graph
,
handler7
);
AddStatis
(
found_subgraph_count
);
AddStatis
(
found_subgraph_count
);
}
}
...
...
paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
浏览文件 @
29782728
...
@@ -95,7 +95,6 @@ struct Fc : public PatternBase {
...
@@ -95,7 +95,6 @@ struct Fc : public PatternBase {
PATTERN_DECL_NODE
(
fc_input
);
PATTERN_DECL_NODE
(
fc_input
);
PATTERN_DECL_NODE
(
fc_op
);
PATTERN_DECL_NODE
(
fc_op
);
PATTERN_DECL_NODE
(
fc_out
);
};
};
struct
Activation
:
public
PatternBase
{
struct
Activation
:
public
PatternBase
{
...
@@ -108,6 +107,17 @@ struct Activation : public PatternBase {
...
@@ -108,6 +107,17 @@ struct Activation : public PatternBase {
PATTERN_DECL_NODE
(
activation_op
);
PATTERN_DECL_NODE
(
activation_op
);
PATTERN_DECL_NODE
(
activation_out
);
PATTERN_DECL_NODE
(
activation_out
);
};
};
struct
FusedTokenPrune
:
public
PatternBase
{
FusedTokenPrune
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"fused_token_prune"
)
{}
void
operator
()();
PATTERN_DECL_NODE
(
fused_token_prune_input
);
PATTERN_DECL_NODE
(
fused_token_prune_op
);
PATTERN_DECL_NODE
(
fused_token_prune_output
);
};
}
// namespace patterns
}
// namespace patterns
class
RemovePaddingRecoverPaddingPass
:
public
FusePassBase
{
class
RemovePaddingRecoverPaddingPass
:
public
FusePassBase
{
...
...
paddle/fluid/inference/tensorrt/convert/fused_token_prune_op.cc
浏览文件 @
29782728
...
@@ -52,8 +52,21 @@ class FusedTokenPruneOpConverter : public OpConverter {
...
@@ -52,8 +52,21 @@ class FusedTokenPruneOpConverter : public OpConverter {
auto
*
word_id
=
engine_
->
GetITensor
(
"word_id"
);
auto
*
word_id
=
engine_
->
GetITensor
(
"word_id"
);
auto
*
pos_id
=
engine_
->
GetITensor
(
"pos_id"
);
auto
*
pos_id
=
engine_
->
GetITensor
(
"pos_id"
);
auto
*
mask_id
=
engine_
->
GetITensor
(
"mask_id"
);
auto
*
mask_id
=
engine_
->
GetITensor
(
"mask_id"
);
// reduce_sum: (-1,headsize,token_length,token_length) ->
// (-1,token_length)
uint32_t
reduce_dim
=
0
;
reduce_dim
|=
1
<<
1
;
// 00000000000000000000000000000010
reduce_dim
|=
1
<<
2
;
// 00000000000000000000000000000110
bool
keep_dim
=
false
;
nvinfer1
::
ReduceOperation
reduce_type
=
nvinfer1
::
ReduceOperation
::
kSUM
;
auto
*
reduce_sum_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
Attn
,
reduce_type
,
reduce_dim
,
keep_dim
);
// reduce_sum_layer->getOutput(0)->setType(reduce_sum_layer->getInput(0)->getType());
auto
*
Reduced
=
reduce_sum_layer
->
getOutput
(
0
);
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
std
::
vector
<
nvinfer1
::
ITensor
*>
itensors
=
{
Attn
,
X
,
Mask
,
NewMask
,
word_id
,
pos_id
,
mask_id
};
Reduced
,
X
,
Mask
,
NewMask
,
word_id
,
pos_id
,
mask_id
};
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
7
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
itensors
.
data
(),
7
,
plugin
);
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
layer
->
getOutput
(
0
)
->
setName
(
output_name
.
c_str
());
...
...
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.cu
浏览文件 @
29782728
...
@@ -31,19 +31,15 @@ namespace inference {
...
@@ -31,19 +31,15 @@ namespace inference {
namespace
tensorrt
{
namespace
tensorrt
{
namespace
plugin
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ElementwiseMask
(
const
T
*
a
,
__global__
void
ElementwiseMask
(
const
T
*
a
,
const
T
*
b
,
const
T
*
b
,
T
*
res
,
T
*
res
,
int
num_elements
)
{
int
num_elements
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
auto
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
auto
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
>=
num_elements
)
return
;
if
(
tid
>=
num_elements
)
return
;
const
T
zero
=
0
;
const
T
zero
=
0
;
res
[
tid
]
=
b
[
tid
]
>=
zero
?
a
[
tid
]
:
zero
;
res
[
tid
]
=
b
[
tid
]
>=
zero
?
a
[
tid
]
:
zero
;
#endif
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -123,7 +119,6 @@ __global__ void ReduceSum2(
...
@@ -123,7 +119,6 @@ __global__ void ReduceSum2(
template
<
>
template
<
>
__global__
void
ReduceSum2
<
half
>
(
__global__
void
ReduceSum2
<
half
>
(
const
half
*
src
,
half
*
dst
,
int
bsz
,
int
nb_head
,
int
max_seq_len
)
{
const
half
*
src
,
half
*
dst
,
int
bsz
,
int
nb_head
,
int
max_seq_len
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
num_blocks_per_head
=
((
max_seq_len
/
blockDim
.
x
)
*
max_seq_len
);
int
num_blocks_per_head
=
((
max_seq_len
/
blockDim
.
x
)
*
max_seq_len
);
...
@@ -155,7 +150,6 @@ __global__ void ReduceSum2<half>(
...
@@ -155,7 +150,6 @@ __global__ void ReduceSum2<half>(
static_cast
<
size_t
>
(
bsz
*
max_seq_len
),
static_cast
<
size_t
>
(
bsz
*
max_seq_len
),
static_cast
<
platform
::
float16
>
(
res_half
[
0
]));
static_cast
<
platform
::
float16
>
(
res_half
[
0
]));
}
}
#endif
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -177,14 +171,81 @@ __global__ void TakeAlongAxis(const T* src,
...
@@ -177,14 +171,81 @@ __global__ void TakeAlongAxis(const T* src,
}
}
}
}
__global__
void
pos_id_prune_kernel
(
const
int32_t
*
src
,
__global__
void
compute_token_length
(
const
int32_t
*
src
,
int32_t
*
dst
,
int32_t
*
dst
,
int
pos_nums
,
float
scale
)
{
float
scale
)
{
int32_t
it
=
threadIdx
.
x
;
dst
[
0
]
=
0
;
dst
[
it
]
=
max
(
static_cast
<
int
>
((
src
[
it
+
1
]
-
src
[
it
])
*
scale
),
1
);
for
(
int
i
=
1
;
i
<
pos_nums
;
i
++
)
{
}
dst
[
i
]
=
dst
[
i
-
1
]
+
max
(
static_cast
<
int
>
((
src
[
i
]
-
src
[
i
-
1
])
*
scale
),
2
);
__global__
void
fill_index_padding_score
(
int32_t
*
token_index
,
const
half
*
scores
,
int32_t
scores_size
,
half
*
padding_scores
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
token_index
[
tid
]
=
threadIdx
.
x
;
if
(
tid
<
scores_size
)
{
padding_scores
[
tid
]
=
scores
[
tid
];
}
else
{
padding_scores
[
tid
]
=
0
;
}
}
template
<
typename
T
,
int
BLOCK_THREADS
,
int
ITEMS_PER_THREAD
>
__global__
void
general_topk_pair_sort
(
T
*
in_keys
,
int32_t
*
in_out_values
)
{
typedef
cub
::
BlockRadixSort
<
T
,
BLOCK_THREADS
,
ITEMS_PER_THREAD
,
int
>
BlockRadixSort
;
typedef
cub
::
BlockLoad
<
T
,
BLOCK_THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_LOAD_TRANSPOSE
>
BlockLoadKey
;
typedef
cub
::
BlockLoad
<
int
,
BLOCK_THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_LOAD_TRANSPOSE
>
BlockLoadValue
;
typedef
cub
::
BlockStore
<
T
,
BLOCK_THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_STORE_TRANSPOSE
>
BlockStoreKey
;
typedef
cub
::
BlockStore
<
int
,
BLOCK_THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_STORE_TRANSPOSE
>
BlockStoreValue
;
__shared__
union
{
typename
BlockRadixSort
::
TempStorage
sort
;
typename
BlockLoadKey
::
TempStorage
loadkey
;
typename
BlockLoadValue
::
TempStorage
loadvalue
;
typename
BlockStoreKey
::
TempStorage
storekey
;
typename
BlockStoreValue
::
TempStorage
storevalue
;
}
temp_storage
;
int
block_offset
=
blockIdx
.
x
*
BLOCK_THREADS
*
ITEMS_PER_THREAD
;
T
thread_keys
[
ITEMS_PER_THREAD
];
int
thread_values
[
ITEMS_PER_THREAD
];
BlockLoadKey
(
temp_storage
.
loadkey
).
Load
(
in_keys
+
block_offset
,
thread_keys
);
BlockLoadValue
(
temp_storage
.
loadvalue
)
.
Load
(
in_out_values
+
block_offset
,
thread_values
);
__syncthreads
();
BlockRadixSort
(
temp_storage
.
sort
).
SortDescending
(
thread_keys
,
thread_values
);
__syncthreads
();
BlockStoreValue
(
temp_storage
.
storevalue
)
.
Store
(
in_out_values
+
block_offset
,
thread_values
);
}
__global__
void
varlen_prune_token
(
const
half
*
tokens
,
const
int32_t
*
token_pos
,
const
int32_t
*
token_index
,
half
*
output
)
{
int
batch
=
blockIdx
.
x
;
int
token_it
=
batch
*
gridDim
.
y
+
blockIdx
.
y
;
int
pre_value_it
=
token_it
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
token_index
[
token_it
]
<
token_pos
[
batch
+
1
]
-
token_pos
[
batch
])
{
output
[(
token_index
[
token_it
]
+
token_pos
[
batch
])
*
gridDim
.
z
*
blockDim
.
x
+
blockIdx
.
z
*
blockDim
.
x
+
threadIdx
.
x
]
=
tokens
[
pre_value_it
];
}
}
}
}
...
@@ -195,9 +256,29 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
...
@@ -195,9 +256,29 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
auto
x_dims
=
inputs
[
1
],
new_mask_dims
=
inputs
[
3
];
auto
x_dims
=
inputs
[
1
],
new_mask_dims
=
inputs
[
3
];
if
(
flag_varseqlen_
)
{
if
(
flag_varseqlen_
)
{
// max sum of seqlen: ceil(sum / scale) + n -1 >= for(i=0;i<n;i++) {sum +=
// floor(num(i) / scale)} auto
// pruned_sum_length=std::ceil(inputs[4].d[0]*new_mask_dims.d[2]/inputs[6].d[1])+
// inputs[1].d[0] - 1;
auto
pruned_sum_length
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kCEIL_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
inputs
[
4
].
d
[
0
],
*
new_mask_dims
.
d
[
2
]),
*
inputs
[
6
].
d
[
1
]),
*
inputs
[
1
].
d
[
0
]),
*
expr_builder
.
constant
(
1
));
if
(
output_index
==
0
)
{
if
(
output_index
==
0
)
{
nvinfer1
::
DimsExprs
ret
=
x_dims
;
nvinfer1
::
DimsExprs
ret
;
ret
.
d
[
1
]
=
new_mask_dims
.
d
[
2
];
ret
.
nbDims
=
4
;
ret
.
d
[
0
]
=
pruned_sum_length
;
ret
.
d
[
1
]
=
x_dims
.
d
[
2
];
ret
.
d
[
2
]
=
expr_builder
.
constant
(
1
);
ret
.
d
[
3
]
=
expr_builder
.
constant
(
1
);
return
ret
;
return
ret
;
}
else
if
(
output_index
==
1
)
{
}
else
if
(
output_index
==
1
)
{
nvinfer1
::
DimsExprs
ret
;
nvinfer1
::
DimsExprs
ret
;
...
@@ -209,18 +290,7 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
...
@@ -209,18 +290,7 @@ nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
// word id
// word id
nvinfer1
::
DimsExprs
ret
;
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
1
;
ret
.
nbDims
=
1
;
// max sum of seqlen: pre_seqlen * new_mask[2] / mask[1] + 2 * batchs
ret
.
d
[
0
]
=
pruned_sum_length
;
const
auto
*
two
=
expr_builder
.
constant
(
2
);
ret
.
d
[
0
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUM
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kFLOOR_DIV
,
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
inputs
[
4
].
d
[
0
],
*
new_mask_dims
.
d
[
2
]),
*
inputs
[
6
].
d
[
1
]),
*
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
two
,
*
inputs
[
6
].
d
[
0
]));
return
ret
;
return
ret
;
}
else
if
(
output_index
==
3
)
{
}
else
if
(
output_index
==
3
)
{
// pos id
// pos id
...
@@ -269,26 +339,18 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
...
@@ -269,26 +339,18 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
flag_varseqlen_
)
{
if
(
flag_varseqlen_
)
{
if
(
pos
==
0
)
{
if
(
pos
<=
3
||
pos
==
7
)
{
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return
(
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#else
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#endif
}
else
{
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
"The FusedTokenPrune TRT Plugin's input type "
"should be half for varseqlen."
));
}
}
}
else
if
(
pos
<=
3
||
pos
==
7
)
{
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
0
];
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
else
if
(
pos
==
6
||
pos
==
11
)
{
// mask_id, mask_id_out
}
else
if
(
pos
==
6
||
pos
==
11
)
{
// mask_id, mask_id_out
return
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
)
;
}
else
{
}
else
{
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
return
in
.
type
==
nvinfer1
::
DataType
::
kINT32
&&
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
...
@@ -296,14 +358,9 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
...
@@ -296,14 +358,9 @@ bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
}
else
{
}
else
{
if
(
pos
==
0
)
{
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return
(
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#else
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
#endif
}
else
{
}
else
{
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
...
@@ -324,9 +381,9 @@ nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
...
@@ -324,9 +381,9 @@ nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
int
nb_inputs
)
const
TRT_NOEXCEPT
{
int
nb_inputs
)
const
TRT_NOEXCEPT
{
if
(
flag_varseqlen_
)
{
if
(
flag_varseqlen_
)
{
if
(
index
==
0
)
{
if
(
index
==
0
)
{
return
input_types
[
1
]
;
return
nvinfer1
::
DataType
::
kHALF
;
}
else
if
(
index
==
4
)
{
}
else
if
(
index
==
4
)
{
// mask id
return
nvinfer1
::
DataType
::
kFLOAT
;
return
input_types
[
6
]
;
}
else
{
}
else
{
// index = 1,2,3
// index = 1,2,3
return
nvinfer1
::
DataType
::
kINT32
;
return
nvinfer1
::
DataType
::
kINT32
;
...
@@ -557,14 +614,6 @@ inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
...
@@ -557,14 +614,6 @@ inline void enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
}
}
}
}
inline
void
pos_id_prune
(
const
int32_t
*
input
,
int32_t
*
output
,
int
pos_nums
,
float
scale
,
cudaStream_t
stream
)
{
pos_id_prune_kernel
<<<
1
,
1
,
0
,
stream
>>>
(
input
,
output
,
pos_nums
,
scale
);
}
int
FusedTokenPrunePluginDynamic
::
enqueue
(
int
FusedTokenPrunePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
...
@@ -572,73 +621,153 @@ int FusedTokenPrunePluginDynamic::enqueue(
...
@@ -572,73 +621,153 @@ int FusedTokenPrunePluginDynamic::enqueue(
void
*
const
*
outputs
,
void
*
const
*
outputs
,
void
*
workspace
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
cudaStream_t
stream
)
TRT_NOEXCEPT
{
auto
input_type
=
input_desc
[
0
].
type
;
auto
attn_dims
=
input_desc
[
0
].
dims
;
auto
bsz
=
attn_dims
.
d
[
0
],
nb_head
=
attn_dims
.
d
[
1
],
max_seq_len
=
attn_dims
.
d
[
2
];
int
device_id
;
cudaGetDevice
(
&
device_id
);
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp32"
;
float
max
=
std
::
numeric_limits
<
float
>::
max
();
enqueueImpl
<
float
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
,
device_id
,
max
,
keep_first_token_
,
keep_order_
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp16"
;
half
max
=
65504.0
;
enqueueImpl
<
half
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
,
device_id
,
max
,
keep_first_token_
,
keep_order_
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) TensorRT Plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"
));
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The FusedTokenPrune TRT Plugin's input type "
"should be float or half."
));
}
if
(
flag_varseqlen_
)
{
if
(
flag_varseqlen_
)
{
if
(
!
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
&&
input_desc
[
1
].
type
==
nvinfer1
::
DataType
::
kHALF
))
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'type must half"
));
}
float
scale
=
float
scale
=
static_cast
<
float
>
(
input_desc
[
3
].
dims
.
d
[
2
])
/
input_desc
[
6
].
dims
.
d
[
1
];
static_cast
<
float
>
(
input_desc
[
3
].
dims
.
d
[
2
])
/
input_desc
[
6
].
dims
.
d
[
1
];
// outputs[2]=inputs[4]; // word_id
const
int32_t
*
inputs5
=
const
int32_t
*
inputs5
=
static_cast
<
const
int32_t
*>
(
inputs
[
5
]);
static_cast
<
const
int32_t
*>
(
inputs
[
5
]);
// pre pos id
int32_t
*
outputs3
=
static_cast
<
int32_t
*>
(
outputs
[
3
]);
int32_t
*
outputs3
=
static_cast
<
int32_t
*>
(
outputs
[
3
]);
// new pos id
pos_id_prune
(
half
*
outputs0
=
static_cast
<
half
*>
(
outputs
[
0
]);
inputs5
,
outputs3
,
input_desc
[
5
].
dims
.
d
[
0
],
scale
,
stream
);
// pos_id
// outputs[4]=inputs[6]; // new_mask
const
int32_t
B
=
input_desc
[
1
].
dims
.
d
[
0
];
// batchs
const
int32_t
max_sequnce_length
=
input_desc
[
1
].
dims
.
d
[
1
];
// max sequnce length
const
int32_t
length
=
input_desc
[
1
].
dims
.
d
[
2
];
// vector length
const
half
*
scores
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
// reduce sum
const
half
*
tokens
=
static_cast
<
const
half
*>
(
inputs
[
1
]);
const
int32_t
scores_size
=
B
*
max_sequnce_length
;
int32_t
padding_token_length
;
if
(
max_sequnce_length
<=
128
)
{
padding_token_length
=
128
;
}
else
if
(
max_sequnce_length
<=
256
)
{
padding_token_length
=
256
;
}
else
if
(
max_sequnce_length
<=
384
)
{
padding_token_length
=
384
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Token_prune'token_length must <= 384"
));
}
// 1. Compute the token length after pruning.
compute_token_length
<<<
1
,
B
,
0
,
stream
>>>
(
inputs5
,
pruned_token_lengths_
,
scale
);
fill_index_padding_score
<<<
B
,
padding_token_length
,
0
,
stream
>>>
(
token_index_
,
scores
,
scores_size
,
padding_scores_
);
// Determine temporary device storage requirements
void
*
d_temp_storage
=
NULL
;
size_t
temp_storage_bytes
=
0
;
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
pruned_token_lengths_
,
outputs3
,
B
+
1
);
// Allocate temporary storage
cudaMalloc
(
&
d_temp_storage
,
temp_storage_bytes
);
// Run exclusive prefix sum
cub
::
DeviceScan
::
ExclusiveSum
(
d_temp_storage
,
temp_storage_bytes
,
pruned_token_lengths_
,
outputs3
,
B
+
1
);
if
(
padding_token_length
==
128
)
{
general_topk_pair_sort
<
half
,
32
,
4
>
<<<
B
,
32
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 128
}
else
if
(
padding_token_length
==
256
)
{
general_topk_pair_sort
<
half
,
64
,
4
>
<<<
B
,
64
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 256
}
else
{
general_topk_pair_sort
<
half
,
96
,
4
>
<<<
B
,
96
,
0
,
stream
>>>
(
padding_scores_
,
token_index_
);
// 384
}
int32_t
num_threads
;
if
(
length
<
1024
)
{
num_threads
=
length
;
}
else
{
if
(
length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
const
dim3
num_blocks
(
B
,
max_sequnce_length
,
length
/
num_threads
);
// batchs, max_sequnce_length, vector_ength/***
varlen_prune_token
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
tokens
,
outputs3
,
token_index_
,
outputs0
);
}
else
{
auto
input_type
=
input_desc
[
0
].
type
;
auto
attn_dims
=
input_desc
[
0
].
dims
;
auto
bsz
=
attn_dims
.
d
[
0
],
nb_head
=
attn_dims
.
d
[
1
],
max_seq_len
=
attn_dims
.
d
[
2
];
int
device_id
;
cudaGetDevice
(
&
device_id
);
if
(
input_type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp32"
;
float
max
=
std
::
numeric_limits
<
float
>::
max
();
enqueueImpl
<
float
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
,
device_id
,
max
,
keep_first_token_
,
keep_order_
);
}
else
if
(
input_type
==
nvinfer1
::
DataType
::
kHALF
)
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. FusedTokenPrune-->fp16"
;
half
max
=
65504.0
;
enqueueImpl
<
half
>
(
input_desc
,
output_desc
,
inputs
,
outputs
,
workspace
,
stream
,
device_id
,
max
,
keep_first_token_
,
keep_order_
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The FusedTokenPrune TRT Plugin's input type "
"should be float or half."
));
}
}
}
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
#endif
}
// namespace plugin
}
// namespace plugin
}
// namespace tensorrt
}
// namespace tensorrt
}
// namespace inference
}
// namespace inference
...
...
paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h
浏览文件 @
29782728
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -30,11 +31,10 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
...
@@ -30,11 +31,10 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
bool
keep_first_token
,
bool
keep_first_token
,
bool
keep_order
,
bool
keep_order
,
bool
flag_varseqlen
)
bool
flag_varseqlen
)
:
keep_first_token_
(
keep_first_token
),
:
with_fp16_
(
with_fp16
),
keep_first_token_
(
keep_first_token
),
keep_order_
(
keep_order
),
keep_order_
(
keep_order
),
flag_varseqlen_
(
flag_varseqlen
)
{
flag_varseqlen_
(
flag_varseqlen
)
{}
with_fp16_
=
with_fp16
;
}
FusedTokenPrunePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{
FusedTokenPrunePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
keep_first_token_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
keep_first_token_
);
...
@@ -42,8 +42,14 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
...
@@ -42,8 +42,14 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
flag_varseqlen_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
flag_varseqlen_
);
}
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
return
new
FusedTokenPrunePluginDynamic
(
FusedTokenPrunePluginDynamic
*
ptr
=
new
FusedTokenPrunePluginDynamic
(
with_fp16_
,
keep_first_token_
,
keep_order_
,
flag_varseqlen_
);
with_fp16_
,
keep_first_token_
,
keep_order_
,
flag_varseqlen_
);
ptr
->
max_batchs_
=
max_batchs_
;
ptr
->
max_token_length_
=
max_token_length_
;
ptr
->
pruned_token_lengths_
=
pruned_token_lengths_
;
ptr
->
token_index_
=
token_index_
;
ptr
->
padding_scores_
=
padding_scores_
;
return
ptr
;
}
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
...
@@ -84,7 +90,16 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
...
@@ -84,7 +90,16 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nb_inputs
,
int
nb_inputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nb_outputs
)
TRT_NOEXCEPT
override
{}
int
nb_outputs
)
TRT_NOEXCEPT
override
{
max_batchs_
=
in
[
1
].
max
.
d
[
0
];
max_token_length_
=
in
[
1
].
max
.
d
[
1
];
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
pruned_token_lengths_
,
(
max_batchs_
+
1
)
*
sizeof
(
int32_t
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
token_index_
,
max_batchs_
*
max_token_length_
*
sizeof
(
int32_t
)));
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaMalloc
(
&
padding_scores_
,
max_batchs_
*
max_token_length_
*
sizeof
(
half
)));
}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nb_inputs
,
int
nb_inputs
,
...
@@ -106,9 +121,15 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
...
@@ -106,9 +121,15 @@ class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
void
destroy
()
TRT_NOEXCEPT
override
{
delete
this
;
}
private:
private:
bool
with_fp16_
;
bool
keep_first_token_
;
bool
keep_first_token_
;
bool
keep_order_
;
bool
keep_order_
;
bool
flag_varseqlen_
;
bool
flag_varseqlen_
;
int32_t
*
pruned_token_lengths_
;
int32_t
*
token_index_
;
int32_t
max_batchs_
;
int32_t
max_token_length_
;
half
*
padding_scores_
;
};
};
class
FusedTokenPrunePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
class
FusedTokenPrunePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
...
...
paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu
浏览文件 @
29782728
...
@@ -19,9 +19,9 @@ namespace inference {
...
@@ -19,9 +19,9 @@ namespace inference {
namespace
tensorrt
{
namespace
tensorrt
{
namespace
plugin
{
namespace
plugin
{
__global__
void
RecoverPaddingKernel
(
const
float
*
input0
,
__global__
void
RecoverPaddingKernel
(
const
half
*
input0
,
const
int32_t
*
input1
,
const
int32_t
*
input1
,
float
*
output
)
{
half
*
output
)
{
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
if
(
blockIdx
.
y
<
seqence_length
)
{
if
(
blockIdx
.
y
<
seqence_length
)
{
...
@@ -79,7 +79,7 @@ bool RecoverPaddingPlugin::supportsFormatCombination(
...
@@ -79,7 +79,7 @@ bool RecoverPaddingPlugin::supportsFormatCombination(
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
&&
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kFLOAT
&&
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
else
{
}
else
{
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
k
FLOAT
&&
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
k
HALF
&&
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
}
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
...
@@ -114,38 +114,43 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
...
@@ -114,38 +114,43 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const
auto
input0_desc
=
inputDesc
[
0
];
const
auto
input0_desc
=
inputDesc
[
0
];
const
auto
input1_desc
=
inputDesc
[
1
];
const
auto
input1_desc
=
inputDesc
[
1
];
const
auto
input2_desc
=
inputDesc
[
2
];
const
auto
input2_desc
=
inputDesc
[
2
];
const
float
*
input0
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
const
half
*
input0
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
input1
=
const
int32_t
*
input1
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
const
int32_t
vector_length
=
input0_desc
.
dims
.
d
[
1
];
int32_t
num_threads
;
int32_t
num_threads
;
if
(
input0_desc
.
dims
.
d
[
1
]
%
512
==
0
)
{
if
(
vector_length
<
1024
)
{
num_threads
=
512
;
num_threads
=
vector_length
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
input0_desc
.
dims
.
d
[
1
]
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
}
else
{
num_threads
=
1
;
if
(
vector_length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
vector_length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
vector_length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
vector_length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
vector_length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
vector_length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
vector_length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
vector_length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
vector_length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
}
const
dim3
num_blocks
(
const
dim3
num_blocks
(
input1_desc
.
dims
.
d
[
0
]
-
1
,
input1_desc
.
dims
.
d
[
0
]
-
1
,
input2_desc
.
dims
.
d
[
1
],
input2_desc
.
dims
.
d
[
1
],
input0_desc
.
dims
.
d
[
1
]
/
num_threads
);
// batchs, max sequnce length
vector_length
/
num_threads
);
// batchs, max sequnce length
// (mask_id.dims.d[1]),
// (mask_id.dims.d[1]),
// input.dims.d[1]/256
// input.dims.d[1]/***
RecoverPaddingKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
RecoverPaddingKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
input0
,
input1
,
output
);
input0
,
input1
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
...
...
paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu
浏览文件 @
29782728
...
@@ -19,9 +19,9 @@ namespace inference {
...
@@ -19,9 +19,9 @@ namespace inference {
namespace
tensorrt
{
namespace
tensorrt
{
namespace
plugin
{
namespace
plugin
{
__global__
void
RemovePaddingKernel
(
const
float
*
input0
,
__global__
void
RemovePaddingKernel
(
const
half
*
input0
,
const
int32_t
*
input1
,
const
int32_t
*
input1
,
float
*
output
)
{
half
*
output
)
{
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int
word_id
=
blockIdx
.
x
*
gridDim
.
y
+
blockIdx
.
y
;
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
int32_t
seqence_length
=
input1
[
blockIdx
.
x
+
1
]
-
input1
[
blockIdx
.
x
];
if
(
blockIdx
.
y
<
seqence_length
)
{
if
(
blockIdx
.
y
<
seqence_length
)
{
...
@@ -73,7 +73,7 @@ bool RemovePaddingPlugin::supportsFormatCombination(
...
@@ -73,7 +73,7 @@ bool RemovePaddingPlugin::supportsFormatCombination(
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
else
{
}
else
{
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
k
FLOAT
&&
return
inOut
[
pos
].
type
==
nvinfer1
::
DataType
::
k
HALF
&&
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
inOut
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
;
}
}
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
// return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
...
@@ -106,38 +106,43 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
...
@@ -106,38 +106,43 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
void
*
workspace
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
{
cudaStream_t
stream
)
TRT_NOEXCEPT
{
const
auto
input_desc
=
inputDesc
[
0
];
const
auto
input_desc
=
inputDesc
[
0
];
const
float
*
input0
=
static_cast
<
const
float
*>
(
inputs
[
0
]);
const
half
*
input0
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
input1
=
const
int32_t
*
input1
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
// pos_id_tensor
float
*
output
=
static_cast
<
float
*>
(
outputs
[
0
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
const
auto
input0_desc
=
inputDesc
[
0
];
const
auto
input0_desc
=
inputDesc
[
0
];
const
int32_t
vector_length
=
input0_desc
.
dims
.
d
[
2
];
int32_t
num_threads
;
int32_t
num_threads
;
if
(
input0_desc
.
dims
.
d
[
2
]
%
512
==
0
)
{
if
(
vector_length
<
1024
)
{
num_threads
=
512
;
num_threads
=
vector_length
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
input0_desc
.
dims
.
d
[
2
]
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
}
else
{
num_threads
=
1
;
if
(
vector_length
%
512
==
0
)
{
num_threads
=
512
;
}
else
if
(
vector_length
%
256
==
0
)
{
num_threads
=
256
;
}
else
if
(
vector_length
%
128
==
0
)
{
num_threads
=
128
;
}
else
if
(
vector_length
%
64
==
0
)
{
num_threads
=
64
;
}
else
if
(
vector_length
%
32
==
0
)
{
num_threads
=
32
;
}
else
if
(
vector_length
%
16
==
0
)
{
num_threads
=
16
;
}
else
if
(
vector_length
%
8
==
0
)
{
num_threads
=
8
;
}
else
if
(
vector_length
%
4
==
0
)
{
num_threads
=
4
;
}
else
if
(
vector_length
%
2
==
0
)
{
num_threads
=
2
;
}
else
{
num_threads
=
1
;
}
}
}
const
dim3
num_blocks
(
const
dim3
num_blocks
(
input0_desc
.
dims
.
d
[
0
],
input0_desc
.
dims
.
d
[
0
],
input0_desc
.
dims
.
d
[
1
],
input0_desc
.
dims
.
d
[
1
],
input0_desc
.
dims
.
d
[
2
]
/
vector_length
/
num_threads
);
// batchs, max sequnce length, input
.dims.d[2]/256
num_threads
);
// batchs, max sequnce length, input
0.dims.d[2]/***
RemovePaddingKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
RemovePaddingKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
input0
,
input1
,
output
);
input0
,
input1
,
output
);
...
...
paddle/fluid/inference/tensorrt/plugin/test_fused_token_prune_plugin.cc
浏览文件 @
29782728
...
@@ -26,12 +26,9 @@ TEST(fused_token_prune_op_plugin, test_plugin) {
...
@@ -26,12 +26,9 @@ TEST(fused_token_prune_op_plugin, test_plugin) {
/*keep_first_token*/
false
,
/*keep_first_token*/
false
,
/*keep_order*/
true
,
/*keep_order*/
true
,
/*flag_varseqlen*/
false
);
/*flag_varseqlen*/
false
);
plugin
.
configurePlugin
(
nullptr
,
4
,
nullptr
,
2
);
plugin
.
initialize
();
plugin
.
initialize
();
plugin
.
getPluginType
();
plugin
.
getPluginType
();
plugin
.
getNbOutputs
();
plugin
.
getNbOutputs
();
auto
clone_plugin
=
plugin
.
clone
();
clone_plugin
->
destroy
();
size_t
buf_size
=
plugin
.
getSerializationSize
();
size_t
buf_size
=
plugin
.
getSerializationSize
();
std
::
vector
<
char
>
buf
(
buf_size
);
std
::
vector
<
char
>
buf
(
buf_size
);
plugin
.
serialize
(
buf
.
data
());
plugin
.
serialize
(
buf
.
data
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录