Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d43bb7f2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d43bb7f2
编写于
8月 28, 2020
作者:
Z
zlsh80826
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert mask with fp32/fp16 support
上级
d4dcc80d
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
127 addition
and
332 deletion
+127
-332
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-1
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+0
-19
paddle/fluid/inference/tensorrt/convert/mul_op.cc
paddle/fluid/inference/tensorrt/convert/mul_op.cc
+19
-8
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+15
-42
paddle/fluid/inference/tensorrt/convert/scale_op.cc
paddle/fluid/inference/tensorrt/convert/scale_op.cc
+0
-8
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+2
-2
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu
+0
-85
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h
+0
-120
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
...le/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
+69
-41
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
+20
-5
未找到文件。
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
d43bb7f2
...
@@ -1029,7 +1029,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor);
...
@@ -1029,7 +1029,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER
(
elementwise_max_tensor
);
USE_TRT_CONVERTER
(
elementwise_max_tensor
);
USE_TRT_CONVERTER
(
elementwise_min_tensor
);
USE_TRT_CONVERTER
(
elementwise_min_tensor
);
USE_TRT_CONVERTER
(
elementwise_pow_tensor
);
USE_TRT_CONVERTER
(
elementwise_pow_tensor
);
USE_TRT_CONVERTER
(
mul
);
USE_TRT_CONVERTER
(
m
atm
ul
);
USE_TRT_CONVERTER
(
conv2d
);
USE_TRT_CONVERTER
(
conv2d
);
USE_TRT_CONVERTER
(
relu
);
USE_TRT_CONVERTER
(
relu
);
USE_TRT_CONVERTER
(
sigmoid
);
USE_TRT_CONVERTER
(
sigmoid
);
...
...
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
d43bb7f2
...
@@ -11,7 +11,6 @@ limitations under the License. */
...
@@ -11,7 +11,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -81,24 +80,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
...
@@ -81,24 +80,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
layer
=
nullptr
;
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
with_dynamic_shape
())
{
auto
pos_tensor
=
engine_
->
GetITensor
(
"eval_placeholder_2"
);
plugin
::
CastIntPluginDynamic
*
cast_plugin
=
new
plugin
::
CastIntPluginDynamic
();
auto
cast_layer
=
engine_
->
AddPluginV2
(
&
pos_tensor
,
1
,
cast_plugin
);
auto
casted_pos_tensor
=
cast_layer
->
getOutput
(
0
);
auto
reshape_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
casted_pos_tensor
);
nvinfer1
::
Dims2
reshape_dim
(
0
,
0
);
nvinfer1
::
Permutation
perm
{
1
,
0
,
2
};
reshape_layer
->
setFirstTranspose
(
perm
);
reshape_layer
->
setReshapeDimensions
(
reshape_dim
);
auto
imask_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Reduce
,
*
reshape_layer
->
getOutput
(
0
),
nvinfer1
::
ReduceOperation
::
kMAX
,
1
,
false
);
engine_
->
SetITensor
(
"imask_tensor"
,
imask_layer
->
getOutput
(
0
));
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
<
float
>
(
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
<
float
>
(
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
...
...
paddle/fluid/inference/tensorrt/convert/mul_op.cc
浏览文件 @
d43bb7f2
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -31,17 +32,27 @@ class MulOpConverter : public OpConverter {
...
@@ -31,17 +32,27 @@ class MulOpConverter : public OpConverter {
// Declare inputs
// Declare inputs
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
bool
transpose_x
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_X"
));
bool
transpose_y
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_Y"
));
#ifdef USE_NVINFER_PLUGIN
nvinfer1
::
DataType
type
=
(
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
;
plugin
::
ConvertMaskPluginDynamic
*
plugin
=
new
plugin
::
ConvertMaskPluginDynamic
(
type
);
auto
convert_mask_layer
=
engine_
->
AddPluginV2
(
&
input1
,
1
,
plugin
);
engine_
->
SetITensor
(
"qkv_plugin_mask"
,
convert_mask_layer
->
getOutput
(
0
));
#endif
// Both the input1 and input2 do not need transpose.
// Both the input1 and input2 do not need transpose.
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
false
,
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
false
);
transpose_x
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
transpose_y
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
RreplenishLayerAndOutput
(
layer
,
"matmul"
,
{
output_name
},
test_mode
);
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
engine_
->
DeclareOutput
(
output_name
);
}
}
}
};
};
...
@@ -49,4 +60,4 @@ class MulOpConverter : public OpConverter {
...
@@ -49,4 +60,4 @@ class MulOpConverter : public OpConverter {
}
// namespace inference
}
// namespace inference
}
// namespace paddle
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
mul
,
MulOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
m
atm
ul
,
MulOpConverter
);
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
d43bb7f2
...
@@ -113,33 +113,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -113,33 +113,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast
<
void
*>
(
bias_data
),
static_cast
<
void
*>
(
bias_data
),
static_cast
<
int32_t
>
(
bias_t
->
numel
())};
static_cast
<
int32_t
>
(
bias_t
->
numel
())};
nvinfer1
::
Permutation
permutation
{
0
,
1
,
2
,
3
,
4
};
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
auto
trans_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
input
);
weight
,
bias
);
trans_layer
->
setFirstTranspose
(
permutation
);
auto
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
trans_layer
->
getOutput
(
0
),
n
,
weight
,
bias
);
/*
auto pos_tensor = engine_->GetITensor("eval_placeholder_2");
plugin::CastIntPluginDynamic* cast_plugin =
new plugin::CastIntPluginDynamic();
auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin);
auto casted_pos_tensor = cast_layer->getOutput(0);
auto reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor);
nvinfer1::Dims2 reshape_dim(0, 0);
nvinfer1::Permutation perm{1, 0, 2};
reshape_layer->setFirstTranspose(perm);
reshape_layer->setReshapeDimensions(reshape_dim);
auto reduce_layer =
TRT_ENGINE_ADD_LAYER(engine_, Reduce,
*reshape_layer->getOutput(0),
nvinfer1::ReduceOperation::kMAX, 1, false);
*/
// auto imask_tensor = engine_->GetITensor("imask_tensor");
auto
imask_tensor
=
engine_
->
GetITensor
(
"fused_mha_mask"
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"1"
);
"CustomQKVToContextPluginDynamic"
,
"1"
);
...
@@ -154,28 +131,24 @@ class MultiheadMatMulOpConverter : public OpConverter {
...
@@ -154,28 +131,24 @@ class MultiheadMatMulOpConverter : public OpConverter {
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
};
nvinfer1
::
PluginFieldCollection
*
plugin
Ptr
=
nvinfer1
::
PluginFieldCollection
*
plugin
_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin
Ptr
)
+
malloc
(
sizeof
(
*
plugin
_collection
)
+
fields
.
size
()
*
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
pluginPtr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_collection
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
plugin_collection
->
fields
=
fields
.
data
();
auto
plugin
=
creator
->
createPlugin
(
"CustomQKVToContextPluginDynamic"
,
plugin_collection
);
free
(
plugin_collection
);
auto
pluginObj
=
creator
->
createPlugin
(
"CustomQKVToContextPluginDynamic"
,
pluginPtr
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
push_back
(
fc_layer
->
getOutput
(
0
));
// plugin_inputs.push_back(reduce_layer->getOutput(0));
plugin_inputs
.
push_back
(
mask_tensor
);
plugin_inputs
.
push_back
(
imask_tensor
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
pluginObj
);
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
assert
(
plugin_layer
!=
nullptr
);
layer
=
plugin_layer
;
auto
trans_r_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
plugin_layer
->
getOutput
(
0
));
assert
(
trans_r_layer
!=
nullptr
);
trans_r_layer
->
setFirstTranspose
(
permutation
);
layer
=
trans_r_layer
;
#else
#else
// transpose weight_data from m * n to n * m
// transpose weight_data from m * n to n * m
auto
*
input_bias_qk
=
auto
*
input_bias_qk
=
...
...
paddle/fluid/inference/tensorrt/convert/scale_op.cc
浏览文件 @
d43bb7f2
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -27,7 +26,6 @@ class ScaleOpConverter : public OpConverter {
...
@@ -27,7 +26,6 @@ class ScaleOpConverter : public OpConverter {
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid scale op to tensorrt mul layer without bias"
;
VLOG
(
3
)
<<
"convert a fluid scale op to tensorrt mul layer without bias"
;
std
::
cerr
<<
"Scale converter"
<<
std
::
endl
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
// Declare inputs
...
@@ -66,12 +64,6 @@ class ScaleOpConverter : public OpConverter {
...
@@ -66,12 +64,6 @@ class ScaleOpConverter : public OpConverter {
platform
::
errors
::
Fatal
(
platform
::
errors
::
Fatal
(
"Paddle-TRT scale mode only support dimension >= 3"
));
"Paddle-TRT scale mode only support dimension >= 3"
));
plugin
::
ConvertMaskPluginDynamic
*
plugin
=
new
plugin
::
ConvertMaskPluginDynamic
();
auto
convert_mask_layer
=
engine_
->
AddPluginV2
(
&
input
,
1
,
plugin
);
convert_mask_layer
->
setName
(
"convert_mask_layer"
);
engine_
->
SetITensor
(
"fused_mha_mask"
,
convert_mask_layer
->
getOutput
(
0
));
nvinfer1
::
IShuffleLayer
*
expand_layer
=
nullptr
;
nvinfer1
::
IShuffleLayer
*
expand_layer
=
nullptr
;
nvinfer1
::
IShuffleLayer
*
squeeze_layer
=
nullptr
;
nvinfer1
::
IShuffleLayer
*
squeeze_layer
=
nullptr
;
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
d43bb7f2
...
@@ -43,7 +43,7 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -43,7 +43,7 @@ struct SimpleOpTypeSetTeller : public Teller {
private:
private:
// use this set for no calib int8.
// use this set for no calib int8.
std
::
unordered_set
<
std
::
string
>
int8_teller_set
{
"mul"
,
std
::
unordered_set
<
std
::
string
>
int8_teller_set
{
"m
atm
ul"
,
"conv2d"
,
"conv2d"
,
"pool2d"
,
"pool2d"
,
"relu"
,
"relu"
,
...
@@ -59,7 +59,7 @@ struct SimpleOpTypeSetTeller : public Teller {
...
@@ -59,7 +59,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_mul"
,
"elementwise_mul"
,
"conv2d_transpose"
};
"conv2d_transpose"
};
std
::
unordered_set
<
std
::
string
>
teller_set
{
std
::
unordered_set
<
std
::
string
>
teller_set
{
"mul"
,
"m
atm
ul"
,
"conv2d"
,
"conv2d"
,
"pool2d"
,
"pool2d"
,
"relu"
,
"relu"
,
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
d43bb7f2
...
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
...
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
cast_int_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
stack_op_plugin.cu convert_mask_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu
已删除
100644 → 0
浏览文件 @
d4dcc80d
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
nvinfer1
::
DimsExprs
CastIntPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
assert
(
output_index
==
0
);
return
inputs
[
0
];
}
bool
CastIntPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
{
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
return
(
in
.
type
==
nvinfer1
::
DataType
::
kINT32
);
}
nvinfer1
::
DataType
CastIntPluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The Cast Int only has one input, so the "
"index value should be 0, but get %d."
,
index
));
return
input_types
[
index
];
}
__global__
void
castIntKernel
(
const
int64_t
*
input
,
int32_t
*
output
,
size_t
num_elements
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
>=
num_elements
)
return
;
output
[
idx
]
=
input
[
idx
]
+
1
;
}
int
CastIntPluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
auto
output_dims
=
output_desc
[
0
].
dims
;
size_t
num_elements
=
ProductDim
(
input_dims
);
size_t
out_num_elements
=
ProductDim
(
output_dims
);
assert
(
input_type
==
nvinfer1
::
DataType
::
kINT32
);
// although the input is int64_t
assert
(
num_elements
==
out_num_elements
);
const
size_t
num_threads
=
256
;
castIntKernel
<<<
num_elements
/
num_threads
+
1
,
num_threads
>>>
(
static_cast
<
const
int64_t
*>
(
inputs
[
0
]),
static_cast
<
int32_t
*>
(
outputs
[
0
]),
num_elements
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h
已删除
100644 → 0
浏览文件 @
d4dcc80d
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
CastIntPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
CastIntPluginDynamic
()
{}
CastIntPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
~
CastIntPluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
CastIntPluginDynamic
();
}
const
char
*
getPluginType
()
const
override
{
return
"cast_int_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
{
return
0
;
}
size_t
getSerializationSize
()
const
override
{
return
0
;
}
void
serialize
(
void
*
buffer
)
const
override
{}
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nb_inputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nb_outputs
)
override
{}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nb_inputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nb_outputs
)
const
override
{
return
0
;
}
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
override
;
void
destroy
()
override
{
delete
this
;
}
};
class
CastIntPluginV2Creator
:
public
nvinfer1
::
IPluginCreator
{
public:
CastIntPluginV2Creator
()
{}
const
char
*
getPluginName
()
const
override
{
return
"cast_int_plugin"
;
}
const
char
*
getPluginVersion
()
const
override
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
{
return
nullptr
;
}
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
{
auto
plugin
=
new
CastIntPluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
getPluginNamespace
()
const
override
{
return
plugin_namespace_
.
c_str
();
}
private:
std
::
string
plugin_namespace_
;
std
::
string
plugin_name_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
REGISTER_TRT_PLUGIN_V2
(
CastIntPluginV2Creator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
浏览文件 @
d43bb7f2
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <vector>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
namespace
paddle
{
namespace
paddle
{
namespace
inference
{
namespace
inference
{
...
@@ -38,15 +39,23 @@ constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
...
@@ -38,15 +39,23 @@ constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
nvinfer1
::
DimsExprs
ConvertMaskPluginDynamic
::
getOutputDimensions
(
nvinfer1
::
DimsExprs
ConvertMaskPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
auto
cms128
=
expr_builder
.
constant
(
packedMaskSize128
);
assert
(
output_index
==
0
);
auto
fp16maskSize
=
expr_builder
.
operation
(
if
(
type_
==
nvinfer1
::
DataType
::
kHALF
)
{
nvinfer1
::
DimensionOperation
::
kPROD
,
*
cms128
,
*
expr_builder
.
constant
(
2
));
auto
cms128
=
expr_builder
.
constant
(
packedMaskSize128
);
auto
fp16maskSize
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kPROD
,
*
cms128
,
*
expr_builder
.
constant
(
2
));
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
2
;
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
0
];
ret
.
d
[
1
]
=
fp16maskSize
;
return
ret
;
}
nvinfer1
::
DimsExprs
ret
;
nvinfer1
::
DimsExprs
ret
;
ret
.
nbDims
=
2
;
ret
.
nbDims
=
1
;
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
0
];
ret
.
d
[
0
]
=
inputs
[
0
].
d
[
0
];
ret
.
d
[
1
]
=
fp16maskSize
;
return
ret
;
return
ret
;
}
}
...
@@ -54,22 +63,21 @@ bool ConvertMaskPluginDynamic::supportsFormatCombination(
...
@@ -54,22 +63,21 @@ bool ConvertMaskPluginDynamic::supportsFormatCombination(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
{
int
nb_outputs
)
{
const
nvinfer1
::
PluginTensorDesc
&
desc
=
in_out
[
pos
];
const
nvinfer1
::
PluginTensorDesc
&
desc
=
in_out
[
pos
];
/* input: [B, S,
S
] */
/* input: [B, S,
1
] */
/* output: [B, 2*maskSize] */
/* output: [B, 2*maskSize] */
assert
(
nb_inputs
==
1
);
assert
(
nb_inputs
==
1
);
assert
(
nb_outputs
==
1
);
assert
(
nb_outputs
==
1
);
if
(
pos
==
0
)
{
if
(
pos
==
0
)
{
std
::
cerr
<<
"desc.type: "
<<
static_cast
<
int
>
(
desc
.
type
)
<<
" "
<<
desc
.
dims
.
nbDims
<<
std
::
endl
;
return
((
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
return
((
desc
.
type
==
nvinfer1
::
DataType
::
kFLOAT
||
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
desc
.
dims
.
nbDims
==
3
);
desc
.
dims
.
nbDims
==
3
);
}
}
std
::
cerr
<<
"output.type: "
<<
static_cast
<
int
>
(
desc
.
type
)
<<
" "
// return true;
<<
desc
.
dims
.
nbDims
<<
std
::
endl
;
/* fp16 -> fp16, fp32 -> int32 */
// return desc.type == nvinfer1::DataType::kHALF;
if
(
type_
==
nvinfer1
::
DataType
::
kHALF
)
return
true
;
return
desc
.
type
==
nvinfer1
::
DataType
::
kHALF
;
return
desc
.
type
==
nvinfer1
::
DataType
::
kINT32
;
}
}
nvinfer1
::
DataType
ConvertMaskPluginDynamic
::
getOutputDataType
(
nvinfer1
::
DataType
ConvertMaskPluginDynamic
::
getOutputDataType
(
...
@@ -79,16 +87,36 @@ nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
...
@@ -79,16 +87,36 @@ nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
"The convert mask plugin only has one input, so the "
"The convert mask plugin only has one input, so the "
"index value should be 0, but get %d."
,
"index value should be 0, but get %d."
,
index
));
index
));
return
nvinfer1
::
DataType
::
kHALF
;
if
(
type_
==
nvinfer1
::
DataType
::
kHALF
)
{
return
nvinfer1
::
DataType
::
kHALF
;
}
return
nvinfer1
::
DataType
::
kINT32
;
}
}
/* half [B, S, 1] -> int [S, B, 1] */
template
<
typename
T
>
template
<
typename
T
>
__global__
void
CastToIntAndReduce
(
const
T
*
input
,
int
*
output
,
int
seq_len
,
__global__
void
FullMaskPreprocess
(
const
T
*
input
,
int
*
output
,
int
seq_len
,
int
batch
)
{
int
batch
)
{
int
bid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
sid
=
threadIdx
.
x
;
int
sid
=
threadIdx
.
x
;
output
[
sid
*
batch
+
bid
]
=
output
[
sid
*
batch
+
bid
]
=
static_cast
<
int
>
(
input
[
bid
*
seq_len
+
sid
]);
static_cast
<
int
>
(
input
[
bid
*
seq_len
*
seq_len
+
sid
]);
}
/* float [B, S, 1] -> int [B] */
/* [[1. 1. 1. 0. 0.], -> [3, 4]
[1. 1. 1. 1. 0.]] */
__global__
void
IMaskPreprocess
(
const
float
*
input
,
int
*
output
,
int
seq_len
,
int
batch
)
{
float
sum
=
0.
f
;
int
bid
=
blockIdx
.
x
;
int
sid
=
threadIdx
.
x
;
float
thread_data
=
input
[
bid
*
seq_len
+
sid
];
sum
=
paddle
::
operators
::
math
::
blockReduceSum
<
float
>
(
thread_data
,
0xffffffff
);
if
(
sid
==
0
)
{
output
[
bid
]
=
static_cast
<
int
>
(
sum
);
}
}
}
__global__
void
fillSBSMaskKernel
(
const
uint32_t
warps_m
,
__global__
void
fillSBSMaskKernel
(
const
uint32_t
warps_m
,
...
@@ -159,33 +187,33 @@ int ConvertMaskPluginDynamic::enqueue(
...
@@ -159,33 +187,33 @@ int ConvertMaskPluginDynamic::enqueue(
int
batch
=
input_dims
.
d
[
0
];
int
batch
=
input_dims
.
d
[
0
];
int
seq_len
=
input_dims
.
d
[
1
];
int
seq_len
=
input_dims
.
d
[
1
];
assert
(
num_elements
==
out_num_elements
*
seq_len
);
assert
(
seq_len
==
128
);
assert
(
seq_len
<=
1024
);
assert
(
output_desc
.
type
==
nvinfer1
::
DataType
::
kHALF
);
// temp use, should remove
int
*
inputMaskSB
;
cudaMalloc
(
&
inputMaskSB
,
batch
*
seq_len
*
sizeof
(
int
));
if
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
if
(
type_
==
nvinfer1
::
DataType
::
kFLOAT
)
{
CastToIntAndReduce
<
float
><<<
batch
,
seq_len
,
0
,
stream
>>>
(
IMaskPreprocess
<<<
batch
,
seq_len
,
0
,
stream
>>>
(
static_cast
<
const
float
*>
(
inputs
[
0
]),
inputMaskSB
,
seq_len
,
batch
);
static_cast
<
const
float
*>
(
inputs
[
0
]),
static_cast
<
int
*>
(
outputs
[
0
]),
seq_len
,
batch
);
}
else
{
}
else
{
CastToIntAndReduce
<
half
><<<
batch
,
seq_len
,
0
,
stream
>>>
(
int
*
inputMaskSB
;
static_cast
<
const
half
*>
(
inputs
[
0
]),
inputMaskSB
,
seq_len
,
batch
);
cudaMalloc
(
&
inputMaskSB
,
batch
*
seq_len
*
sizeof
(
int
));
if
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kFLOAT
)
{
FullMaskPreprocess
<
float
><<<
batch
,
seq_len
,
0
,
stream
>>>
(
static_cast
<
const
float
*>
(
inputs
[
0
]),
inputMaskSB
,
seq_len
,
batch
);
}
else
{
FullMaskPreprocess
<
half
><<<
batch
,
seq_len
,
0
,
stream
>>>
(
static_cast
<
const
half
*>
(
inputs
[
0
]),
inputMaskSB
,
seq_len
,
batch
);
}
size_t
warps_m
=
0
,
warps_n
=
0
,
warps_k
=
1
;
if
(
seq_len
==
128
)
{
warps_m
=
2
;
warps_n
=
2
;
}
convertMask
(
seq_len
,
batch
,
warps_m
,
warps_n
,
warps_k
,
inputMaskSB
,
static_cast
<
uint32_t
*>
(
outputs
[
0
]),
stream
);
cudaFree
(
inputMaskSB
);
}
}
assert
(
seq_len
==
128
);
size_t
warps_m
=
0
,
warps_n
=
0
,
warps_k
=
1
;
if
(
seq_len
==
128
)
{
warps_m
=
2
;
warps_n
=
2
;
}
convertMask
(
seq_len
,
batch
,
warps_m
,
warps_n
,
warps_k
,
inputMaskSB
,
static_cast
<
uint32_t
*>
(
outputs
[
0
]),
stream
);
cudaFree
(
inputMaskSB
);
return
cudaGetLastError
()
!=
cudaSuccess
;
return
cudaGetLastError
()
!=
cudaSuccess
;
}
}
#endif
#endif
...
...
paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
浏览文件 @
d43bb7f2
...
@@ -27,20 +27,32 @@ namespace plugin {
...
@@ -27,20 +27,32 @@ namespace plugin {
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
class
ConvertMaskPluginDynamic
:
public
DynamicPluginTensorRT
{
class
ConvertMaskPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
public:
ConvertMaskPluginDynamic
()
{}
explicit
ConvertMaskPluginDynamic
(
nvinfer1
::
DataType
type
)
:
type_
(
type
)
{
ConvertMaskPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
assert
(
type
==
nvinfer1
::
DataType
::
kHALF
||
type
==
nvinfer1
::
DataType
::
kFLOAT
);
}
ConvertMaskPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
type_
);
}
~
ConvertMaskPluginDynamic
()
{}
~
ConvertMaskPluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
{
return
new
ConvertMaskPluginDynamic
();
return
new
ConvertMaskPluginDynamic
(
type_
);
}
}
const
char
*
getPluginType
()
const
override
{
return
"convert_mask_plugin"
;
}
const
char
*
getPluginType
()
const
override
{
return
"convert_mask_plugin"
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
getNbOutputs
()
const
override
{
return
1
;
}
int
initialize
()
override
{
return
0
;
}
int
initialize
()
override
{
return
0
;
}
size_t
getSerializationSize
()
const
override
{
return
0
;
}
size_t
getSerializationSize
()
const
override
{
void
serialize
(
void
*
buffer
)
const
override
{}
size_t
serialize_size
=
0
;
serialize_size
+=
SerializedSize
(
type_
);
return
serialize_size
;
}
void
serialize
(
void
*
buffer
)
const
override
{
SerializeValue
(
&
buffer
,
type_
);
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
...
@@ -71,6 +83,9 @@ class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
...
@@ -71,6 +83,9 @@ class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
int
nb_inputs
)
const
override
;
int
nb_inputs
)
const
override
;
void
destroy
()
override
{
delete
this
;
}
void
destroy
()
override
{
delete
this
;
}
private:
nvinfer1
::
DataType
type_
;
};
};
class
ConvertMaskPluginV2Creator
:
public
nvinfer1
::
IPluginCreator
{
class
ConvertMaskPluginV2Creator
:
public
nvinfer1
::
IPluginCreator
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录