Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
d4dcc80d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d4dcc80d
编写于
8月 27, 2020
作者:
Z
zlsh80826
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MHA fp16
上级
03acac2b
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
327 addition
and
8 deletion
+327
-8
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+2
-1
paddle/fluid/inference/tensorrt/convert/op_converter.h
paddle/fluid/inference/tensorrt/convert/op_converter.h
+0
-2
paddle/fluid/inference/tensorrt/convert/scale_op.cc
paddle/fluid/inference/tensorrt/convert/scale_op.cc
+8
-0
paddle/fluid/inference/tensorrt/convert/ut_helper.h
paddle/fluid/inference/tensorrt/convert/ut_helper.h
+0
-2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
...le/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
+196
-0
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
+120
-0
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+0
-2
未找到文件。
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
d4dcc80d
...
...
@@ -138,7 +138,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
*reshape_layer->getOutput(0),
nvinfer1::ReduceOperation::kMAX, 1, false);
*/
auto
imask_tensor
=
engine_
->
GetITensor
(
"imask_tensor"
);
// auto imask_tensor = engine_->GetITensor("imask_tensor");
auto
imask_tensor
=
engine_
->
GetITensor
(
"fused_mha_mask"
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"1"
);
...
...
paddle/fluid/inference/tensorrt/convert/op_converter.h
浏览文件 @
d4dcc80d
...
...
@@ -173,8 +173,6 @@ class OpConverter {
"optim_input_shape should be same."
));
}
}
std
::
cerr
<<
"Declare input: "
<<
input
<<
std
::
endl
;
if
(
input
.
find
(
"stack_0.tmp_0"
)
!=
std
::
string
::
npos
)
continue
;
engine
->
DeclareInput
(
input
,
FluidDataType2TRT
(
var
->
Proto
()
->
type
().
lod_tensor
().
tensor
().
data_type
()),
...
...
paddle/fluid/inference/tensorrt/convert/scale_op.cc
浏览文件 @
d4dcc80d
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -26,6 +27,7 @@ class ScaleOpConverter : public OpConverter {
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid scale op to tensorrt mul layer without bias"
;
std
::
cerr
<<
"Scale converter"
<<
std
::
endl
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
...
...
@@ -64,6 +66,12 @@ class ScaleOpConverter : public OpConverter {
platform
::
errors
::
Fatal
(
"Paddle-TRT scale mode only support dimension >= 3"
));
plugin
::
ConvertMaskPluginDynamic
*
plugin
=
new
plugin
::
ConvertMaskPluginDynamic
();
auto
convert_mask_layer
=
engine_
->
AddPluginV2
(
&
input
,
1
,
plugin
);
convert_mask_layer
->
setName
(
"convert_mask_layer"
);
engine_
->
SetITensor
(
"fused_mha_mask"
,
convert_mask_layer
->
getOutput
(
0
));
nvinfer1
::
IShuffleLayer
*
expand_layer
=
nullptr
;
nvinfer1
::
IShuffleLayer
*
squeeze_layer
=
nullptr
;
...
...
paddle/fluid/inference/tensorrt/convert/ut_helper.h
浏览文件 @
d4dcc80d
...
...
@@ -183,8 +183,6 @@ class TRTConvertValidation {
std
::
vector
<
void
*>
buffers
(
num_bindings
);
for
(
const
std
::
string
&
name
:
input_output_names
)
{
// std::cerr << "Binding name: " << name << std::endl;
if
(
name
.
find
(
"stack_0.tmp_0"
)
!=
std
::
string
::
npos
)
continue
;
auto
*
var
=
scope_
.
FindVar
(
name
);
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
const
int
bind_index
=
engine_
->
engine
()
->
getBindingIndex
(
name
.
c_str
());
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
d4dcc80d
...
...
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
cast_int_plugin.cu stack_op_plugin.cu
cast_int_plugin.cu stack_op_plugin.cu
convert_mask_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
0 → 100644
浏览文件 @
d4dcc80d
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
/* This plugin converts the matmul output of shape [B, S, S] into the packed
   mask layout used by the bertQKV fused_multihead_attention kernels. */
// Sizing constants for the sequence-length-128 packed mask.
// threadsPerCta128: threads per CTA, i.e. 2 x 2 warps of 32 lanes (the same
// warps_m = 2 / warps_n = 2 configuration that enqueue() selects for S = 128).
constexpr size_t threadsPerCta128 = 32 * 2 * 2;
// xmmasM128: number of CTAs launched along the M direction for S = 128.
constexpr size_t xmmasM128 = 4;
// packedMaskSize128: number of packed 32-bit mask words emitted per batch
// entry (one word per thread per M tile).
constexpr size_t packedMaskSize128 = threadsPerCta128 * xmmasM128;
// Describes the plugin output shape as [inputs[0].d[0], 2 * packedMaskSize128].
//
// The first output dimension is forwarded symbolically from the first input
// dimension; the second is the constant product 2 * packedMaskSize128 built
// through the expression builder. `output_index` and `nb_inputs` are not
// inspected.
nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) {
  const auto* packed_words = expr_builder.constant(packedMaskSize128);
  const auto* two = expr_builder.constant(2);

  nvinfer1::DimsExprs out_dims;
  out_dims.nbDims = 2;
  out_dims.d[0] = inputs[0].d[0];
  // The output is declared as fp16, so each packed 32-bit word spans two
  // output elements.
  out_dims.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
                                         *packed_words, *two);
  return out_dims;
}
// Reports whether the tensor at position `pos` of `in_out` has a type/shape
// this plugin accepts.
//
// Layout of `in_out`: index 0 is the single input ([B, S, S]), index 1 is the
// single output ([B, 2 * maskSize]).
//   - input  (pos == 0): must be kFLOAT or kHALF and have exactly 3 dims.
//   - output (pos != 0): every combination is accepted.
//
// The previous implementation dumped the queried type and rank to std::cerr on
// every call; that debug output has been removed (the builder calls this
// function many times and <iostream> was never included by this file).
bool ConvertMaskPluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
    int nb_outputs) {
  /* input: [B, S, S] */
  /* output: [B, 2*maskSize] */
  assert(nb_inputs == 1);
  assert(nb_outputs == 1);

  const nvinfer1::PluginTensorDesc& desc = in_out[pos];
  if (pos == 0) {
    const bool type_ok = desc.type == nvinfer1::DataType::kFLOAT ||
                         desc.type == nvinfer1::DataType::kHALF;
    return type_ok && desc.dims.nbDims == 3;
  }

  // NOTE(review): getOutputDataType() always reports kHALF, but the stricter
  // check `desc.type == nvinfer1::DataType::kHALF` was deliberately left
  // disabled by the original author; confirm before re-enabling it.
  return true;
}
// Returns the data type of output `index`.
//
// The plugin has exactly one output (getNbOutputs() returns 1), so `index`
// must be 0; any other value raises InvalidArgument. The output is always
// reported as kHALF regardless of `input_types`.
nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
  // `index` selects an output, so the message talks about outputs (the old
  // text incorrectly said "one input").
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The convert mask plugin only has one output, so the "
                        "index value should be 0, but get %d.",
                        index));
  return nvinfer1::DataType::kHALF;
}
// Copies one run of values per batch entry into an int buffer, transposing
// the result from batch-major to position-major order.
//
// Launch layout: one block per batch entry (blockIdx.x), one thread per
// position (threadIdx.x). Thread (b, s) reads input[b * seq_len * seq_len + s]
// and writes it, cast to int, to output[s * batch + b].
//
// NOTE(review): despite the name no reduction is performed; only the first
// seq_len values of each seq_len x seq_len slice are read (presumably the
// first row of a [B, S, S] mask - confirm against the caller).
template <typename T>
__global__ void CastToIntAndReduce(const T* input, int* output, int seq_len,
                                   int batch) {
  const int batch_idx = blockIdx.x;
  const int pos = threadIdx.x;
  const T* slice = input + batch_idx * seq_len * seq_len;
  output[pos * batch + batch_idx] = static_cast<int>(slice[pos]);
}
// Packs a per-position int mask into per-thread 32-bit words.
//
// Launch layout: grid = (M tiles, batch), one CTA of blockDim.x threads per
// (tile, batch) pair, with S * sizeof(int) bytes of dynamic shared memory.
//
//   inputMaskSB : int mask indexed as [s * B + b] (B = gridDim.y).
//   inputMaskX  : output words indexed as
//                 [(b * gridDim.x + tile) * blockDim.x + thread].
//
// Each thread derives a column from its warp and lane, then for every N tile
// `ni` fills byte `ni` of its word: bits 0..7 are taken from mask positions
// base + {0, 1, 0, 1, 8, 9, 8, 9}, a bit being set when the mask value equals
// 1. The word does not depend on the M-tile index, so all tiles of a batch
// entry store identical data.
__global__ void fillSBSMaskKernel(const uint32_t warps_m,
                                  const uint32_t warps_n, const uint32_t S,
                                  const int* inputMaskSB,
                                  uint32_t* inputMaskX) {
  // Holds the S mask elements of the batch entry handled by this CTA.
  extern __shared__ int shm_mask[];

  const uint32_t cta_size = blockDim.x;
  const uint32_t num_m_tiles = gridDim.x;
  const uint32_t num_batches = gridDim.y;
  const uint32_t m_tile = blockIdx.x;
  const uint32_t batch_id = blockIdx.y;
  const uint32_t tid = threadIdx.x;

  const size_t num_n_tiles = (S + 16 * warps_n - 1) / (16 * warps_n);
  const size_t warp_id = tid / 32;
  const size_t lane_id = tid % 32;
  const size_t warp_col = warp_id / warps_m;
  const size_t col = warp_col * 16 + (lane_id % 4) * 2;

  // Stage this batch entry's mask in shared memory. The global reads are
  // strided because the input is laid out position-major (S x B).
  for (uint32_t s = tid; s < S; s += cta_size) {
    shm_mask[s] = inputMaskSB[s * num_batches + batch_id];
  }
  __syncthreads();

  uint32_t packed = 0u;
  for (size_t ni = 0; ni < num_n_tiles; ++ni) {
    const int base = ni * 16 * warps_n + col;
    for (int bit = 0; bit < 8; ++bit) {
      // Source offsets for bits 0..7 are 0, 1, 0, 1, 8, 9, 8, 9.
      const int src = base + (bit & 1) + 8 * (bit >> 2);
      const uint32_t flag = shm_mask[src] == 1.f ? 1u : 0u;
      packed |= flag << (8 * ni + bit);
    }
  }

  inputMaskX[(batch_id * num_m_tiles + m_tile) * cta_size + tid] = packed;
}
void
convertMask
(
const
uint32_t
S
,
const
uint32_t
B
,
const
uint32_t
warps_m
,
const
uint32_t
warps_n
,
const
uint32_t
warps_k
,
const
int
*
inputMaskSB
,
uint32_t
*
inputMaskX
,
cudaStream_t
stream
)
{
const
size_t
xmmas_m
=
(
S
+
16
*
warps_m
-
1
)
/
(
16
*
warps_m
);
const
size_t
threads_per_cta
=
warps_m
*
warps_n
*
warps_k
*
32
;
dim3
grid
(
xmmas_m
,
B
);
fillSBSMaskKernel
<<<
grid
,
threads_per_cta
,
S
*
sizeof
(
int
),
stream
>>>
(
warps_m
,
warps_n
,
S
,
inputMaskSB
,
inputMaskX
);
}
// Runs the mask conversion on `stream`.
//
// Steps:
//   1. CastToIntAndReduce copies seq_len values per batch entry from the
//      float/half input into a temporary int buffer laid out as S x B.
//   2. convertMask packs that buffer into outputs[0] as 32-bit words.
//
// Only seq_len == 128 is implemented (warps_m = warps_n = 2, warps_k = 1).
// Returns 0 on success and 1 on failure (unsupported shape, allocation
// failure, or any CUDA error reported after the launches).
//
// Changes from the previous version:
//   - unsupported shapes now return an error; before, with asserts compiled
//     out, seq_len != 128 left warps_m == 0 and convertMask divided by zero;
//   - the output-type assert now indexes output_desc[0] (output_desc is a
//     pointer, so `output_desc.type` could not compile with asserts enabled);
//   - the element-count assert was dropped: it compared B*S*S against
//     B*(2*packedMaskSize128)*S, which never holds for this plugin's shapes;
//   - the cudaMalloc / cudaFree results are checked.
int ConvertMaskPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) {
  const auto& input_dims = input_desc[0].dims;
  const int batch = input_dims.d[0];
  const int seq_len = input_dims.d[1];

  assert(output_desc[0].type == nvinfer1::DataType::kHALF);

  // Reject anything the kernels below cannot handle instead of crashing.
  if (input_dims.nbDims != 3 || batch <= 0 || seq_len != 128) {
    return 1;
  }

  // TODO: temporary buffer; this should come from `workspace` once
  // getWorkspaceSize() reports the required size, so that enqueue() no longer
  // allocates on every call.
  int* inputMaskSB = nullptr;
  const size_t mask_bytes =
      static_cast<size_t>(batch) * static_cast<size_t>(seq_len) * sizeof(int);
  if (cudaMalloc(&inputMaskSB, mask_bytes) != cudaSuccess) {
    cudaGetLastError();  // clear the error state left by the failed call
    return 1;
  }

  // One block per batch entry, one thread per sequence position.
  if (input_desc[0].type == nvinfer1::DataType::kFLOAT) {
    CastToIntAndReduce<float><<<batch, seq_len, 0, stream>>>(
        static_cast<const float*>(inputs[0]), inputMaskSB, seq_len, batch);
  } else {
    CastToIntAndReduce<half><<<batch, seq_len, 0, stream>>>(
        static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch);
  }

  // Warp tiling for seq_len == 128.
  const size_t warps_m = 2;
  const size_t warps_n = 2;
  const size_t warps_k = 1;
  convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB,
              static_cast<uint32_t*>(outputs[0]), stream);

  const cudaError_t free_status = cudaFree(inputMaskSB);
  const cudaError_t last_status = cudaGetLastError();
  return (free_status != cudaSuccess || last_status != cudaSuccess) ? 1 : 0;
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
0 → 100644
浏览文件 @
d4dcc80d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
// Dynamic-shape TensorRT plugin that produces the packed attention mask.
//
// The class declares no data members of its own: serialization writes zero
// bytes and the deserializing constructor ignores its buffer. Shape, format,
// output-type and enqueue logic are defined out of line.
class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
 public:
  ConvertMaskPluginDynamic() {}

  // Deserializing constructor; both arguments are unused.
  ConvertMaskPluginDynamic(const void* serial_data, size_t serial_length) {}

  ~ConvertMaskPluginDynamic() {}

  // Returns a freshly default-constructed plugin; ownership passes to the
  // caller (released through destroy()).
  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new ConvertMaskPluginDynamic();
  }

  // Type name; matches ConvertMaskPluginV2Creator::getPluginName().
  const char* getPluginType() const override { return "convert_mask_plugin"; }

  int getNbOutputs() const override { return 1; }

  // Nothing to set up; always reports success.
  int initialize() override { return 0; }

  // Nothing is serialized.
  size_t getSerializationSize() const override { return 0; }
  void serialize(void* buffer) const override {}

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder) override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* in_out,
                                 int nb_inputs, int nb_outputs) override;

  // No per-configuration state is kept.
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nb_inputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nb_outputs) override {}

  // No TensorRT-managed workspace is requested.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nb_inputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nb_outputs) const override {
    return 0;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
              const nvinfer1::PluginTensorDesc* output_desc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* input_types,
                                       int nb_inputs) const override;

  // The plugin is heap-allocated (see clone()), so it deletes itself.
  void destroy() override { delete this; }
};
// Plugin creator for ConvertMaskPluginDynamic (name "convert_mask_plugin",
// version "1"), registered through REGISTER_TRT_PLUGIN_V2 below.
//
// Only deserialization yields a plugin: createPlugin() returns nullptr and
// the advertised field collection is empty.
class ConvertMaskPluginV2Creator : public nvinfer1::IPluginCreator {
 public:
  ConvertMaskPluginV2Creator() {}

  const char* getPluginName() const override { return "convert_mask_plugin"; }

  const char* getPluginVersion() const override { return "1"; }

  // Returns the (empty) field collection.
  const nvinfer1::PluginFieldCollection* getFieldNames() override {
    return &field_collection_;
  }

  // Creation from a field collection is not supported; always nullptr.
  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
    return nullptr;
  }

  // Builds a plugin from a serialized buffer; the caller owns the result.
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length) override {
    return new ConvertMaskPluginDynamic(serial_data, serial_length);
  }

  void setPluginNamespace(const char* lib_namespace) override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

REGISTER_TRT_PLUGIN_V2(ConvertMaskPluginV2Creator);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
d4dcc80d
...
...
@@ -227,8 +227,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
// Bind input tensor to TRT.
for
(
const
auto
&
x
:
Inputs
(
"Xs"
))
{
if
(
param_names_
.
count
(
x
))
continue
;
// std::cerr << "runTRT name: " << x << std::endl;
if
(
x
.
find
(
"stack_0.tmp_0"
)
!=
std
::
string
::
npos
)
continue
;
// convert input and copy to TRT engine's buffer
auto
&
t
=
inference
::
analysis
::
GetFromScope
<
framework
::
LoDTensor
>
(
scope
,
x
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录