Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2318fb0e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2318fb0e
编写于
11月 05, 2020
作者:
S
Shang Zhizhou
提交者:
GitHub
11月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cherry-pick
ea851796
in develop (#28390)
上级
e9651068
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
895 addition
and
76 deletion
+895
-76
paddle/fluid/inference/analysis/argument.h
paddle/fluid/inference/analysis/argument.h
+1
-0
paddle/fluid/inference/analysis/ir_pass_manager.cc
paddle/fluid/inference/analysis/ir_pass_manager.cc
+1
-0
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
...id/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+14
-1
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+3
-0
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-1
paddle/fluid/inference/api/paddle_analysis_config.h
paddle/fluid/inference/api/paddle_analysis_config.h
+17
-0
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
...fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+85
-9
paddle/fluid/inference/tensorrt/convert/matmul_op.cc
paddle/fluid/inference/tensorrt/convert/matmul_op.cc
+90
-0
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
...e/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+130
-27
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+39
-6
paddle/fluid/inference/tensorrt/convert/slice_op.cc
paddle/fluid/inference/tensorrt/convert/slice_op.cc
+23
-7
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+10
-2
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+18
-0
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+1
-1
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
...e/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+177
-0
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h
...le/fluid/inference/tensorrt/plugin/special_slice_plugin.h
+96
-0
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
...nce/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
+3
-3
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
...fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
+7
-6
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+3
-1
paddle/fluid/platform/dynload/tensorrt.cc
paddle/fluid/platform/dynload/tensorrt.cc
+27
-10
paddle/fluid/platform/dynload/tensorrt.h
paddle/fluid/platform/dynload/tensorrt.h
+30
-1
paddle/fluid/pybind/inference_api.cc
paddle/fluid/pybind/inference_api.cc
+2
-0
python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
...fluid/tests/unittests/ir/inference/inference_pass_test.py
+4
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py
...dle/fluid/tests/unittests/ir/inference/test_trt_matmul.py
+111
-0
未找到文件。
paddle/fluid/inference/analysis/argument.h
浏览文件 @
2318fb0e
...
...
@@ -207,6 +207,7 @@ struct Argument {
DECL_ARGUMENT_FIELD
(
tensorrt_use_static_engine
,
TensorRtUseStaticEngine
,
bool
);
DECL_ARGUMENT_FIELD
(
tensorrt_use_calib_mode
,
TensorRtUseCalibMode
,
bool
);
DECL_ARGUMENT_FIELD
(
tensorrt_use_oss
,
TensorRtUseOSS
,
bool
);
DECL_ARGUMENT_FIELD
(
lite_passes_filter
,
LitePassesFilter
,
std
::
vector
<
std
::
string
>
);
...
...
paddle/fluid/inference/analysis/ir_pass_manager.cc
浏览文件 @
2318fb0e
...
...
@@ -95,6 +95,7 @@ void IRPassManager::CreatePasses(Argument *argument,
bool
use_calib_mode
=
argument
->
tensorrt_use_calib_mode
();
pass
->
Set
(
"enable_int8"
,
new
bool
(
enable_int8
));
pass
->
Set
(
"use_calib_mode"
,
new
bool
(
use_calib_mode
));
pass
->
Set
(
"use_oss"
,
new
bool
(
argument
->
tensorrt_use_oss
()));
pass
->
Set
(
"precision_mode"
,
new
AnalysisConfig
::
Precision
(
precision_mode
));
...
...
paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
浏览文件 @
2318fb0e
...
...
@@ -117,11 +117,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
block_desc
.
Proto
()
->
set_idx
(
0
);
LOG
(
INFO
)
<<
"--- detect a sub-graph with "
<<
subgraph
.
size
()
<<
" nodes"
;
bool
has_fused_embedding_eltwise_layernorm
=
false
;
bool
has_multihead_matmul
=
false
;
for
(
auto
*
node
:
subgraph
)
{
auto
*
new_block_op
=
new_block
->
AppendOp
();
auto
*
op
=
block_desc
.
AppendOp
();
*
new_block_op
->
Proto
()
=
*
node
->
Op
()
->
Proto
();
*
op
->
Proto
()
=
*
node
->
Op
()
->
Proto
();
if
(
!
has_fused_embedding_eltwise_layernorm
&&
op
->
Type
()
==
"fused_embedding_eltwise_layernorm"
)
{
has_fused_embedding_eltwise_layernorm
=
true
;
}
if
(
!
has_multihead_matmul
&&
op
->
Type
()
==
"multihead_matmul"
)
{
has_multihead_matmul
=
true
;
}
}
// Then, we will use the input_names_with_id and output_names_with_id to
...
...
@@ -308,6 +317,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
precision_mode
,
calibrator
.
get
(),
Get
<
int
>
(
"gpu_device_id"
),
min_input_shape
,
max_input_shape
,
opt_input_shape
,
disable_trt_plugin_fp16
);
trt_engine
->
SetUseOSS
(
Get
<
bool
>
(
"use_oss"
));
trt_engine
->
SetWithErnie
(
has_multihead_matmul
&&
has_fused_embedding_eltwise_layernorm
);
bool
need_serialize
=
(
use_static_engine
&&
!
load_from_memory
);
if
(
need_serialize
)
{
...
...
@@ -386,4 +398,5 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.
EQ
(
"instance_norm"
,
0
)
.
EQ
(
"gelu"
,
0
)
.
EQ
(
"layer_norm"
,
0
)
.
EQ
(
"scale"
,
0
));
.
EQ
(
"scale"
,
0
)
.
EQ
(
"matmul"
,
0
));
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
2318fb0e
...
...
@@ -121,6 +121,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER
(
tensorrt_precision_mode_
);
CP_MEMBER
(
trt_use_static_engine_
);
CP_MEMBER
(
trt_use_calib_mode_
);
CP_MEMBER
(
trt_use_oss_
);
// MKLDNN related.
CP_MEMBER
(
use_mkldnn_
);
CP_MEMBER
(
mkldnn_enabled_op_types_
);
...
...
@@ -274,6 +275,8 @@ void AnalysisConfig::SetTRTDynamicShapeInfo(
disable_trt_plugin_fp16_
=
disable_trt_plugin_fp16
;
}
void
AnalysisConfig
::
EnableTensorRtOSS
()
{
trt_use_oss_
=
true
;
}
// TODO(Superjomn) refactor this, buggy.
void
AnalysisConfig
::
Update
()
{
auto
info
=
SerializeInfoCache
();
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
2318fb0e
...
...
@@ -470,6 +470,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_
.
SetTensorRtPrecisionMode
(
config_
.
tensorrt_precision_mode_
);
argument_
.
SetTensorRtUseStaticEngine
(
config_
.
trt_use_static_engine_
);
argument_
.
SetTensorRtUseCalibMode
(
config_
.
trt_use_calib_mode_
);
argument_
.
SetTensorRtUseOSS
(
config_
.
trt_use_oss_
);
argument_
.
SetMinInputShape
(
config_
.
min_input_shape_
);
argument_
.
SetMaxInputShape
(
config_
.
max_input_shape_
);
argument_
.
SetOptimInputShape
(
config_
.
optim_input_shape_
);
...
...
@@ -1055,7 +1056,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER
(
elementwise_max_tensor
);
USE_TRT_CONVERTER
(
elementwise_min_tensor
);
USE_TRT_CONVERTER
(
elementwise_pow_tensor
);
USE_TRT_CONVERTER
(
mul
);
USE_TRT_CONVERTER
(
m
atm
ul
);
USE_TRT_CONVERTER
(
conv2d
);
USE_TRT_CONVERTER
(
relu
);
USE_TRT_CONVERTER
(
sigmoid
);
...
...
paddle/fluid/inference/api/paddle_analysis_config.h
浏览文件 @
2318fb0e
...
...
@@ -312,6 +312,22 @@ struct PD_INFER_DECL AnalysisConfig {
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
,
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape
,
bool
disable_trt_plugin_fp16
=
false
);
///
/// \brief Replace some TensorRT plugins to TensorRT OSS(
/// https://github.com/NVIDIA/TensorRT), with which some models's inference
/// may
/// be more high-performance. Libnvinfer_plugin.so greater than V7.2.1 is
/// needed.
///
void
EnableTensorRtOSS
();
///
/// \brief A boolean state telling whether to use the TensorRT OSS.
///
/// \return bool Whether to use the TensorRT OSS.
///
bool
tensorrt_oss_enabled
()
{
return
trt_use_oss_
;
}
///
/// \brief Turn on the usage of Lite sub-graph engine.
///
...
...
@@ -569,6 +585,7 @@ struct PD_INFER_DECL AnalysisConfig {
Precision
tensorrt_precision_mode_
{
Precision
::
kFloat32
};
bool
trt_use_static_engine_
{
false
};
bool
trt_use_calib_mode_
{
true
};
bool
trt_use_oss_
{
false
};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
min_input_shape_
{};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape_
{};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
optim_input_shape_
{};
...
...
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
浏览文件 @
2318fb0e
# Add TRT tests
nv_library
(
tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
SRCS m
atm
ul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
...
...
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
浏览文件 @
2318fb0e
...
...
@@ -49,6 +49,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_ids
.
push_back
(
engine_
->
GetITensor
(
id_names
[
i
]));
}
// input_embs[0]: word_embedding
// input_embs[1]: pos_embedding
// input_embs[2]: sent_embedding
std
::
vector
<
float
*>
input_embs
;
std
::
vector
<
int
>
emb_sizes
;
...
...
@@ -85,15 +88,91 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
get_persistable_data
(
op_desc
.
Input
(
"Scale"
).
front
(),
&
scale_dims
);
int64_t
bias_size
=
framework
::
product
(
bias_dims
);
int64_t
scale_size
=
framework
::
product
(
scale_dims
);
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
auto
use_fp16
=
engine_
->
WithFp16
();
auto
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
(
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
eps
,
use_fp16
);
layer
=
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
if
(
engine_
->
use_oss
())
{
int
output_fp16
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
1
:
0
);
PADDLE_ENFORCE_EQ
(
output_fp16
,
1
,
platform
::
errors
::
InvalidArgument
(
"Only Precision::KHalf(fp16) is supported when infering "
"ernie(bert) model with config.EnableTensorRtOSS(). "
"But Precision::KFloat32 is setted."
));
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"bert_embeddings_layernorm_beta"
,
bias
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
bias_size
)},
{
"bert_embeddings_layernorm_gamma"
,
scale
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
scale_size
)},
{
"bert_embeddings_word_embeddings"
,
input_embs
[
0
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
0
])},
{
"bert_embeddings_token_type_embeddings"
,
input_embs
[
2
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
2
])},
{
"bert_embeddings_position_embeddings"
,
input_embs
[
1
],
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
static_cast
<
int32_t
>
(
emb_sizes
[
1
])},
{
"output_fp16"
,
&
output_fp16
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
// remember to free
nvinfer1
::
PluginFieldCollection
*
plugin_ptr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_ptr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
plugin_ptr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_ptr
->
fields
=
fields
.
data
();
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
0
)
->
getName
()));
// word_embedding,
// eval_placeholder_0
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
1
)
->
getName
()));
// sent_embedding,
// eval_placeholder_1
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
2
)
->
getName
()));
// cu_seqlens,
// eval_placeholder_2
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
shape_dim
.
d
[
0
]
=
-
1
;
shuffle_layer
->
setReshapeDimensions
(
shape_dim
);
plugin_inputs
.
emplace_back
(
shuffle_layer
->
getOutput
(
0
));
// max_seqlen, eval_placeholder_3
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomEmbLayerNormPluginDynamic"
,
"2"
);
auto
plugin_obj
=
creator
->
createPlugin
(
"CustomEmbLayerNormPluginDynamic"
,
plugin_ptr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin_obj
);
layer
=
plugin_layer
;
free
(
plugin_ptr
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
,
std
::
string
(
"qkv_plugin_mask"
)},
test_mode
);
}
else
{
bool
use_fp16
=
engine_
->
WithFp16
();
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
plugin
::
DynamicPluginTensorRT
*
plugin
=
nullptr
;
plugin
=
new
plugin
::
EmbEltwiseLayernormPluginDynamic
(
input_embs
,
bias
,
scale
,
emb_sizes
,
bias_size
,
scale_size
,
hidden
,
eps
,
use_fp16
);
layer
=
engine_
->
AddPluginV2
(
input_ids
.
data
(),
input_num
,
plugin
);
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
},
test_mode
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static"
...
...
@@ -102,9 +181,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
" to set the shape information to run the dynamic shape mode."
));
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"emb_eltwise_layernorm"
,
{
output_name
},
test_mode
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
...
...
paddle/fluid/inference/tensorrt/convert/mul_op.cc
→
paddle/fluid/inference/tensorrt/convert/m
atm
ul_op.cc
浏览文件 @
2318fb0e
...
...
@@ -28,25 +28,54 @@ namespace inference {
namespace
tensorrt
{
/*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
* M
atM
ulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
*/
class
MulOpConverter
:
public
OpConverter
{
class
M
atM
ulOpConverter
:
public
OpConverter
{
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
VLOG
(
3
)
<<
"convert a fluid mul op to tensorrt mul layer without bias"
;
VLOG
(
3
)
<<
"convert a fluid m
atm
ul op to tensorrt mul layer without bias"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
// Declare inputs
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
// Both the input1 and input2 do not need transpose.
bool
transpose_X
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_X"
));
bool
transpose_Y
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"transpose_Y"
));
auto
*
layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
false
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
false
);
engine_
,
MatrixMultiply
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input1
),
transpose_X
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
input2
),
transpose_Y
);
float
alpha
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"alpha"
));
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
if
(
fabs
(
alpha
-
1.0
)
<
std
::
numeric_limits
<
float
>::
epsilon
())
{
engine_
->
SetITensor
(
output_name
,
layer
->
getOutput
(
0
));
}
else
{
auto
create_weights
=
[
&
](
float
data
,
const
std
::
string
&
type
)
->
float
*
{
std
::
unique_ptr
<
framework
::
Tensor
>
tmp_tensor
(
new
framework
::
Tensor
());
tmp_tensor
->
Resize
({
1
});
auto
*
tmp_data
=
tmp_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
tmp_data
[
0
]
=
data
;
engine_
->
SetWeights
(
output_name
+
"_add_scale_op_"
+
type
,
std
::
move
(
tmp_tensor
));
return
tmp_data
;
};
float
*
alpha_data
=
create_weights
(
alpha
,
"alpha"
);
float
*
shift_data
=
create_weights
(
0.0
,
"shift"
);
float
*
power_data
=
create_weights
(
1.0
,
"power"
);
TensorRTEngine
::
Weight
nv_alpha
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
alpha_data
),
1
};
TensorRTEngine
::
Weight
nv_shift
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
shift_data
),
1
};
TensorRTEngine
::
Weight
nv_power
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
power_data
),
1
};
auto
*
scale_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Scale
,
*
layer
->
getOutput
(
0
),
nvinfer1
::
ScaleMode
::
kUNIFORM
,
nv_shift
.
get
(),
nv_alpha
.
get
(),
nv_power
.
get
());
engine_
->
SetITensor
(
output_name
,
scale_layer
->
getOutput
(
0
));
}
if
(
test_mode
)
{
// the test framework can not determine which is the
// output, so place the declaration inside.
engine_
->
DeclareOutput
(
output_name
);
...
...
@@ -58,4 +87,4 @@ class MulOpConverter : public OpConverter {
}
// namespace inference
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
m
ul
,
MulOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
m
atmul
,
Mat
MulOpConverter
);
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
浏览文件 @
2318fb0e
...
...
@@ -30,7 +30,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
// Declare inputs
// Shouble be a 5 dims tensor.
auto
*
input
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Input"
).
front
());
auto
*
input_bias_qk
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"BiasQK"
).
front
());
// fc weights and fc bias
auto
weight_name
=
op_desc
.
Input
(
"W"
).
front
();
...
...
@@ -50,7 +49,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
//
(hidden, 3, all_head_size)
// (hidden, 3, all_head_size)
auto
weight_dims
=
weight_t
->
dims
();
int
hidden
=
weight_dims
[
0
];
// channels_in
...
...
@@ -65,36 +64,140 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
}
};
// transpose weight_data from m * n to n * m
tranpose_weight
(
weight_data_tmp
.
data
(),
weight_data
,
m
,
n
);
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
size_t
>
(
weight_t
->
numel
())};
weight
.
dims
.
assign
({
n
,
m
});
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
size_t
>
(
bias_t
->
numel
())};
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
.
get
(),
bias
.
get
());
auto
*
fc_out
=
fc_layer
->
getOutput
(
0
);
// add qkv to context
int
head_number
=
BOOST_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"head_number"
));
int
head_size
=
all_head_size
/
head_number
;
float
scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"alpha"
));
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_out
);
plugin_inputs
.
push_back
(
input_bias_qk
);
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden
,
head_number
,
head_size
,
scale
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
2
,
plugin
);
if
(
engine_
->
use_oss
())
{
int
head_size
=
hidden
/
head_number
;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
auto
transpose_weight_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
const
int
HNH
=
H
*
N
*
H
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
hnh
=
0
;
hnh
<
HNH
;
++
hnh
)
{
dst
[
n
*
3
*
HNH
+
i
*
HNH
+
hnh
]
=
src
[
i
*
N
*
HNH
+
n
*
HNH
+
hnh
];
}
}
}
};
// [3, N, H] -> [N, 3, H]
auto
transpose_bias_v2
=
[](
const
float
*
src
,
float
*
dst
,
int
N
,
int
H
)
{
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
h
=
0
;
h
<
H
;
++
h
)
{
dst
[
n
*
3
*
H
+
i
*
H
+
h
]
=
src
[
i
*
N
*
H
+
n
*
H
+
h
];
}
}
}
};
memcpy
(
weight_data_tmp
.
data
(),
weight_data
,
weight_t
->
numel
()
*
sizeof
(
float
));
transpose_weight_v2
(
weight_data_tmp
.
data
(),
weight_data
,
head_number
,
head_size
);
nvinfer1
::
Weights
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
int32_t
>
(
weight_t
->
numel
())};
std
::
vector
<
float
>
bias_data_tmp
;
bias_data_tmp
.
reserve
(
bias_t
->
numel
());
memcpy
(
bias_data_tmp
.
data
(),
bias_data
,
bias_t
->
numel
()
*
sizeof
(
float
));
transpose_bias_v2
(
bias_data_tmp
.
data
(),
bias_data
,
head_number
,
head_size
);
nvinfer1
::
Weights
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
int32_t
>
(
bias_t
->
numel
())};
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
,
bias
);
auto
mask_tensor
=
engine_
->
GetITensor
(
"qkv_plugin_mask"
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomQKVToContextPluginDynamic"
,
"2"
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
bool
has_mask
=
true
;
int
var_seqlen
=
1
;
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"hidden_size"
,
&
hidden
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"num_heads"
,
&
head_number
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"has_mask"
,
&
has_mask
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"var_seqlen"
,
&
var_seqlen
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
};
nvinfer1
::
PluginFieldCollection
*
plugin_collection
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
plugin_collection
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
plugin_collection
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
plugin_collection
->
fields
=
fields
.
data
();
auto
plugin
=
creator
->
createPlugin
(
"CustomQKVToContextPluginDynamic"
,
plugin_collection
);
free
(
plugin_collection
);
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
emplace_back
(
fc_layer
->
getOutput
(
0
));
plugin_inputs
.
emplace_back
(
mask_tensor
);
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
2
)
->
getName
()));
// cu_seqlens,
// eval_placeholder_2
auto
max_seqlen_tensor
=
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
3
)
->
getName
());
auto
*
shuffle_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
const_cast
<
nvinfer1
::
ITensor
*>
(
max_seqlen_tensor
));
nvinfer1
::
Dims
shape_dim
;
shape_dim
.
nbDims
=
1
;
shape_dim
.
d
[
0
]
=
-
1
;
shuffle_layer
->
setReshapeDimensions
(
shape_dim
);
plugin_inputs
.
emplace_back
(
shuffle_layer
->
getOutput
(
0
));
// max_seqlen, eval_placeholder_3
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
*
plugin
);
layer
=
plugin_layer
;
}
else
{
// transpose weight_data from m * n to n * m
auto
*
input_bias_qk
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"BiasQK"
).
front
());
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
size_t
>
(
weight_t
->
numel
())};
weight
.
dims
.
assign
({
n
,
m
});
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
static_cast
<
size_t
>
(
bias_t
->
numel
())};
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
input
,
n
,
weight
.
get
(),
bias
.
get
());
auto
*
fc_out
=
fc_layer
->
getOutput
(
0
);
// add qkv to context
int
head_size
=
all_head_size
/
head_number
;
float
scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"alpha"
));
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
plugin_inputs
.
push_back
(
fc_out
);
plugin_inputs
.
push_back
(
input_bias_qk
);
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
DynamicPluginTensorRT
*
plugin
=
new
plugin
::
QkvToContextPluginDynamic
(
hidden
,
head_number
,
head_size
,
scale
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
2
,
plugin
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static shape mode, which "
...
...
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
浏览文件 @
2318fb0e
...
...
@@ -47,17 +47,50 @@ class SkipLayerNormOpConverter : public OpConverter {
framework
::
DDim
bias_dims
,
scale_dims
;
auto
*
bias
=
get_persistable_data
(
"Bias"
,
&
bias_dims
);
auto
*
scale
=
get_persistable_data
(
"Scale"
,
&
scale_dims
);
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
int
bias_size
=
framework
::
product
(
bias_dims
);
int
scale_size
=
framework
::
product
(
scale_dims
);
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
new
plugin
::
SkipLayerNormPluginDynamic
(
bias
,
scale
,
bias_size
,
scale_size
,
eps
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
inputs
.
data
(),
2
,
plugin
);
if
(
engine_
->
use_oss
())
{
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"2"
);
assert
(
creator
!=
nullptr
);
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
int
ld
=
input1
->
getDimensions
().
d
[
2
];
// hidden dimension
assert
(
ld
>
0
);
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"ld"
,
&
ld
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"beta"
,
bias
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
bias_size
},
{
"gamma"
,
scale
,
nvinfer1
::
PluginFieldType
::
kFLOAT32
,
scale_size
},
};
nvinfer1
::
PluginFieldCollection
*
pluginPtr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
pluginPtr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
pluginPtr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
auto
pluginObj
=
creator
->
createPlugin
(
"CustomSkipLayerNormPluginDynamic"
,
pluginPtr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
assert
(
plugin_layer
!=
nullptr
);
layer
=
plugin_layer
;
}
else
{
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
new
plugin
::
SkipLayerNormPluginDynamic
(
bias
,
scale
,
bias_size
,
scale_size
,
eps
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
inputs
.
data
(),
2
,
plugin
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the Ernie(Bert) model in static"
...
...
paddle/fluid/inference/tensorrt/convert/slice_op.cc
浏览文件 @
2318fb0e
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -77,16 +78,31 @@ class SliceOpConverter : public OpConverter {
nvinfer1
::
ILayer
*
layer
=
nullptr
;
if
(
engine_
->
with_dynamic_shape
())
{
if
(
engine_
->
use_oss
()
&&
engine_
->
with_ernie
())
{
std
::
vector
<
nvinfer1
::
ITensor
*>
plugin_inputs
;
// plugin_inputs.emplace_back(trans_layer->getOutput(0));
plugin_inputs
.
emplace_back
(
input
);
plugin_inputs
.
emplace_back
(
engine_
->
GetITensor
(
engine_
->
network
()
->
getInput
(
2
)
->
getName
()));
// cu_seqlens,
// eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin
::
SpecialSlicePluginDynamic
*
plugin
=
new
plugin
::
SpecialSlicePluginDynamic
();
layer
=
engine_
->
AddPluginV2
(
plugin_inputs
.
data
(),
plugin_inputs
.
size
(),
plugin
);
}
else
{
#if IS_TRT_VERSION_GE(6000)
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
&
input
,
1
,
plugin
);
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePluginDynamic
*
plugin
=
new
plugin
::
SlicePluginDynamic
(
starts
,
ends
,
axes
,
ban_fp16
);
layer
=
engine_
->
AddPluginV2
(
&
input
,
1
,
plugin
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"
));
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"
));
#endif
}
}
else
{
bool
ban_fp16
=
engine_
->
disable_trt_plugin_fp16
();
plugin
::
SlicePlugin
*
plugin
=
...
...
paddle/fluid/inference/tensorrt/engine.h
浏览文件 @
2318fb0e
...
...
@@ -71,9 +71,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) {
template
<
typename
T
>
nvinfer1
::
Dims
Vec2TRT_Dims
(
const
std
::
vector
<
T
>&
shape
,
std
::
string
input
,
bool
with_dynamic_shape
=
false
)
{
PADDLE_ENFORCE_GT
(
shape
.
size
(),
1
UL
,
PADDLE_ENFORCE_GT
(
shape
.
size
(),
0
UL
,
platform
::
errors
::
InvalidArgument
(
"TensorRT's tensor input requires at least
2
"
"TensorRT's tensor input requires at least
1
"
"dimensions, but input %s has %d dims."
,
input
,
shape
.
size
()));
PADDLE_ENFORCE_LE
(
shape
.
size
(),
4UL
,
...
...
@@ -174,6 +174,7 @@ class TensorRTEngine {
"version should be at least 6."
;
#endif
}
dy
::
initLibNvInferPlugins
(
&
logger
,
""
);
}
~
TensorRTEngine
()
{}
...
...
@@ -285,6 +286,9 @@ class TensorRTEngine {
suffix_counter
+=
1
;
}
void
SetUseOSS
(
bool
use_oss
)
{
use_oss_
=
use_oss
;
}
void
SetWithErnie
(
bool
with_ernie
)
{
with_ernie_
=
with_ernie
;
}
void
ClearWeights
()
{
for
(
auto
&
weight_pair
:
weight_map
)
{
weight_pair
.
second
.
reset
(
nullptr
);
...
...
@@ -312,6 +316,8 @@ class TensorRTEngine {
ShapeMapType
min_input_shape
()
{
return
min_input_shape_
;
}
ShapeMapType
max_input_shape
()
{
return
max_input_shape_
;
}
ShapeMapType
optim_input_shape
()
{
return
optim_input_shape_
;
}
bool
use_oss
()
{
return
use_oss_
;
}
bool
with_ernie
()
{
return
with_ernie_
;
}
bool
disable_trt_plugin_fp16
()
{
return
disable_trt_plugin_fp16_
;
}
bool
with_dynamic_shape
()
{
return
with_dynamic_shape_
;
}
...
...
@@ -347,6 +353,8 @@ class TensorRTEngine {
ShapeMapType
max_input_shape_
;
ShapeMapType
optim_input_shape_
;
bool
disable_trt_plugin_fp16_
{
false
};
bool
use_oss_
{
false
};
bool
with_ernie_
{
false
};
nvinfer1
::
ILogger
&
logger_
;
// max data size for the buffers.
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
2318fb0e
...
...
@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -70,6 +72,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"hard_swish"
};
std
::
unordered_set
<
std
::
string
>
teller_set
{
"mul"
,
"matmul"
,
"conv2d"
,
"pool2d"
,
"relu"
,
...
...
@@ -122,6 +125,21 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
(
padding_algorithm
==
"SAME"
&&
op_type
!=
"pool2d"
))
return
false
;
}
if
(
op_type
==
"matmul"
)
{
auto
*
block
=
desc
.
Block
();
for
(
auto
&
param_name
:
desc
.
Inputs
())
{
for
(
auto
&
var_name
:
param_name
.
second
)
{
auto
*
var_desc
=
block
->
FindVar
(
var_name
);
const
auto
shape
=
var_desc
->
GetShape
();
if
(
shape
.
size
()
<
3
)
{
VLOG
(
1
)
<<
"matmul op dims < 3 not supported in tensorrt, but got dims "
<<
shape
.
size
()
<<
", so jump it."
;
return
false
;
}
}
}
}
if
((
*
teller
)(
op_type
,
desc
,
use_no_calib_int8
))
return
true
;
}
return
false
;
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
2318fb0e
...
...
@@ -4,5 +4,5 @@ nv_library(tensorrt_plugin
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu
special_slice_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor
)
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
0 → 100644
浏览文件 @
2318fb0e
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
SpecialSlicePluginDynamic
::
SpecialSlicePluginDynamic
()
{}
SpecialSlicePluginDynamic
::
SpecialSlicePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
{}
SpecialSlicePluginDynamic
::~
SpecialSlicePluginDynamic
()
{}
nvinfer1
::
IPluginV2DynamicExt
*
SpecialSlicePluginDynamic
::
clone
()
const
{
return
new
SpecialSlicePluginDynamic
();
}
const
char
*
SpecialSlicePluginDynamic
::
getPluginType
()
const
{
return
"special_slice_plugin"
;
}
int
SpecialSlicePluginDynamic
::
getNbOutputs
()
const
{
return
1
;
}
int
SpecialSlicePluginDynamic
::
initialize
()
{
return
0
;
}
size_t
SpecialSlicePluginDynamic
::
getSerializationSize
()
const
{
size_t
serialize_size
=
0
;
return
serialize_size
;
}
void
SpecialSlicePluginDynamic
::
serialize
(
void
*
buffer
)
const
{}
nvinfer1
::
DimsExprs
SpecialSlicePluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
{
nvinfer1
::
DimsExprs
output
(
inputs
[
0
]);
auto
one
=
expr_builder
.
constant
(
1
);
output
.
d
[
0
]
=
expr_builder
.
operation
(
nvinfer1
::
DimensionOperation
::
kSUB
,
*
inputs
[
1
].
d
[
0
],
*
one
);
return
output
;
}
void
SpecialSlicePluginDynamic
::
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
{}
size_t
SpecialSlicePluginDynamic
::
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
{
return
0
;
}
void
SpecialSlicePluginDynamic
::
destroy
()
{
delete
this
;
}
void
SpecialSlicePluginDynamic
::
terminate
()
{}
bool
SpecialSlicePluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
desc
,
int
nb_inputs
,
int
nb_outputs
)
{
if
(
pos
==
0
)
// slice tensor
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
&&
desc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
// || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
if
(
pos
==
1
)
// cu_seqlen
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kINT32
&&
desc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
return
(
desc
[
pos
].
type
==
nvinfer1
::
DataType
::
kHALF
&&
desc
[
pos
].
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
// || desc[pos].type ==
// nvinfer1::DataType::kFLOAT);
}
nvinfer1
::
DataType
SpecialSlicePluginDynamic
::
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
{
PADDLE_ENFORCE_EQ
(
index
,
0
,
platform
::
errors
::
InvalidArgument
(
"The index should be equal to 0"
));
return
input_types
[
0
];
}
template
<
typename
T
>
__global__
void
SpecialSliceKernel
(
const
T
*
slice_input
,
const
int32_t
*
cu_seqlens
,
T
*
output
)
{
const
int
hidden
=
blockDim
.
x
;
const
int
batch
=
blockIdx
.
x
;
output
[
batch
*
hidden
+
threadIdx
.
x
]
=
slice_input
[
cu_seqlens
[
batch
]
*
hidden
+
threadIdx
.
x
];
}
int
SpecialSlicePluginDynamic
::
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
{
auto
input_dims
=
input_desc
[
0
].
dims
;
// (sum(S), 768, 1, 1)
auto
out_dims
=
output_desc
[
0
].
dims
;
// (batch, 768, 1, 1)
assert
(
input_desc
[
0
].
type
==
nvinfer1
::
DataType
::
kHALF
);
const
int32_t
hidden
=
input_dims
.
d
[
1
];
const
int
num_blocks
=
out_dims
.
d
[
0
];
// batch size
const
int
num_threads
=
hidden
;
const
half
*
slice_input
=
static_cast
<
const
half
*>
(
inputs
[
0
]);
const
int32_t
*
cu_seqlens
=
static_cast
<
const
int32_t
*>
(
inputs
[
1
]);
half
*
output
=
static_cast
<
half
*>
(
outputs
[
0
]);
SpecialSliceKernel
<<<
num_blocks
,
num_threads
,
0
,
stream
>>>
(
slice_input
,
cu_seqlens
,
output
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
SpecialSlicePluginDynamicCreator
::
SpecialSlicePluginDynamicCreator
()
{}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginName
()
const
{
return
"special_slice_plugin"
;
}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginVersion
()
const
{
return
"1"
;
}
const
nvinfer1
::
PluginFieldCollection
*
SpecialSlicePluginDynamicCreator
::
getFieldNames
()
{
return
&
field_collection_
;
}
nvinfer1
::
IPluginV2
*
SpecialSlicePluginDynamicCreator
::
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
{
return
new
SpecialSlicePluginDynamic
();
}
nvinfer1
::
IPluginV2
*
SpecialSlicePluginDynamicCreator
::
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
{
auto
plugin
=
new
SpecialSlicePluginDynamic
(
serial_data
,
serial_length
);
return
plugin
;
}
void
SpecialSlicePluginDynamicCreator
::
setPluginNamespace
(
const
char
*
lib_namespace
)
{
plugin_namespace_
=
lib_namespace
;
}
const
char
*
SpecialSlicePluginDynamicCreator
::
getPluginNamespace
()
const
{
return
plugin_namespace_
.
c_str
();
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h
0 → 100644
浏览文件 @
2318fb0e
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
class
SpecialSlicePluginDynamic
:
public
DynamicPluginTensorRT
{
public:
SpecialSlicePluginDynamic
();
SpecialSlicePluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
);
~
SpecialSlicePluginDynamic
();
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
override
;
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
outputIndex
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nbInputs
,
nvinfer1
::
IExprBuilder
&
exprBuilder
)
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
inOut
,
int
nbInputs
,
int
nbOutputs
)
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nbInputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nbOutputs
)
override
;
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nbInputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nbOutputs
)
const
override
;
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
inputDesc
,
const
nvinfer1
::
PluginTensorDesc
*
outputDesc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
inputTypes
,
int
nbInputs
)
const
override
;
const
char
*
getPluginType
()
const
override
;
int
getNbOutputs
()
const
override
;
int
initialize
()
override
;
void
terminate
()
override
;
size_t
getSerializationSize
()
const
override
;
void
serialize
(
void
*
buffer
)
const
override
;
void
destroy
()
override
;
private:
int
axis_
;
int
num_stack_
;
};
class
SpecialSlicePluginDynamicCreator
:
public
nvinfer1
::
IPluginCreator
{
public:
SpecialSlicePluginDynamicCreator
();
const
char
*
getPluginName
()
const
override
;
const
char
*
getPluginVersion
()
const
override
;
const
nvinfer1
::
PluginFieldCollection
*
getFieldNames
()
override
;
nvinfer1
::
IPluginV2
*
createPlugin
(
const
char
*
name
,
const
nvinfer1
::
PluginFieldCollection
*
fc
)
override
;
nvinfer1
::
IPluginV2
*
deserializePlugin
(
const
char
*
name
,
const
void
*
serial_data
,
size_t
serial_length
)
override
;
void
setPluginNamespace
(
const
char
*
lib_namespace
)
override
;
const
char
*
getPluginNamespace
()
const
override
;
private:
std
::
string
plugin_namespace_
;
nvinfer1
::
PluginFieldCollection
field_collection_
{
0
,
nullptr
};
std
::
vector
<
nvinfer1
::
PluginField
>
plugin_attributes_
;
};
REGISTER_TRT_PLUGIN_V2
(
SpecialSlicePluginDynamicCreator
);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc
浏览文件 @
2318fb0e
...
...
@@ -126,17 +126,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
{
"read_file_0.tmp_0"
,
min_shape
},
{
"read_file_0.tmp_1"
,
min_shape
},
{
"read_file_0.tmp_2"
,
min_shape
},
{
"
matmul_0.tmp_0"
,
{
batch
,
min_seq_len
,
min_seq_len
}
}};
{
"
read_file_0.tmp_4"
,
min_shape
}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"read_file_0.tmp_0"
,
max_shape
},
{
"read_file_0.tmp_1"
,
max_shape
},
{
"read_file_0.tmp_2"
,
max_shape
},
{
"
matmul_0.tmp_0"
,
{
batch
,
max_seq_len
,
max_seq_len
}
}};
{
"
read_file_0.tmp_4"
,
max_shape
}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
opt_input_shape
=
{
{
"read_file_0.tmp_0"
,
opt_shape
},
{
"read_file_0.tmp_1"
,
opt_shape
},
{
"read_file_0.tmp_2"
,
opt_shape
},
{
"
matmul_0.tmp_0"
,
{
batch
,
opt_seq_len
,
opt_seq_len
}
}};
{
"
read_file_0.tmp_4"
,
opt_shape
}};
auto
precision
=
AnalysisConfig
::
Precision
::
kFloat32
;
if
(
with_fp16
)
{
...
...
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
浏览文件 @
2318fb0e
...
...
@@ -86,16 +86,16 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
void
trt_ernie
(
bool
with_fp16
,
std
::
vector
<
float
>
result
)
{
AnalysisConfig
config
;
std
::
string
model_dir
=
FLAGS_infer_model
;
SetConfig
(
&
config
,
model_dir
,
true
/* use_gpu */
);
SetConfig
(
&
config
,
model_dir
,
true
);
config
.
SwitchUseFeedFetchOps
(
false
);
int
batch
=
1
;
int
batch
=
32
;
int
min_seq_len
=
1
;
int
max_seq_len
=
128
;
int
opt_seq_len
=
128
;
std
::
vector
<
int
>
min_shape
=
{
batch
,
min_seq_len
,
1
};
std
::
vector
<
int
>
min_shape
=
{
1
,
min_seq_len
,
1
};
std
::
vector
<
int
>
max_shape
=
{
batch
,
max_seq_len
,
1
};
std
::
vector
<
int
>
opt_shape
=
{
batch
,
opt_seq_len
,
1
};
// Set the input's min, max, opt shape
...
...
@@ -103,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
{
"read_file_0.tmp_0"
,
min_shape
},
{
"read_file_0.tmp_1"
,
min_shape
},
{
"read_file_0.tmp_2"
,
min_shape
},
{
"
matmul_0.tmp_0"
,
{
batch
,
min_seq_len
,
min_seq_len
}
}};
{
"
read_file_0.tmp_4"
,
min_shape
}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{
{
"read_file_0.tmp_0"
,
max_shape
},
{
"read_file_0.tmp_1"
,
max_shape
},
{
"read_file_0.tmp_2"
,
max_shape
},
{
"
matmul_0.tmp_0"
,
{
batch
,
max_seq_len
,
max_seq_len
}
}};
{
"
read_file_0.tmp_4"
,
max_shape
}};
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
opt_input_shape
=
{
{
"read_file_0.tmp_0"
,
opt_shape
},
{
"read_file_0.tmp_1"
,
opt_shape
},
{
"read_file_0.tmp_2"
,
opt_shape
},
{
"
matmul_0.tmp_0"
,
{
batch
,
opt_seq_len
,
opt_seq_len
}
}};
{
"
read_file_0.tmp_4"
,
opt_shape
}};
auto
precision
=
AnalysisConfig
::
Precision
::
kFloat32
;
if
(
with_fp16
)
{
...
...
@@ -124,6 +124,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
opt_input_shape
);
std
::
vector
<
float
>
out_data
;
run
(
config
,
&
out_data
);
for
(
size_t
i
=
0
;
i
<
out_data
.
size
();
i
++
)
{
EXPECT_NEAR
(
result
[
i
],
out_data
[
i
],
1e-5
);
}
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
2318fb0e
...
...
@@ -278,9 +278,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
float
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int64_t
>
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
buffers
[
bind_index
]
=
static_cast
<
void
*>
(
t
.
data
<
int32_t
>
());
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The TRT Engine OP only support float
and
int64_t input."
));
"The TRT Engine OP only support float
/int32_t/
int64_t input."
));
}
}
...
...
paddle/fluid/platform/dynload/tensorrt.cc
浏览文件 @
2318fb0e
...
...
@@ -22,19 +22,15 @@ namespace dynload {
std
::
once_flag
tensorrt_dso_flag
;
void
*
tensorrt_dso_handle
;
std
::
once_flag
tensorrt_plugin_dso_flag
;
void
*
tensorrt_plugin_dso_handle
;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
TENSORRT_RAND_ROUTINE_EACH
(
DEFINE_WRAP
);
TENSORRT_PLUGIN_RAND_ROUTINE_EACH
(
DEFINE_WRAP
);
void
*
GetTensorRtHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
std
::
string
dso_name
=
"libnvinfer.dylib"
;
#elif defined(_WIN32)
std
::
string
dso_name
=
"nvinfer.dll"
;
#else
std
::
string
dso_name
=
"libnvinfer.so"
;
#endif
void
*
GetDsoHandle
(
const
std
::
string
&
dso_name
)
{
#if !defined(_WIN32)
int
dynload_flags
=
RTLD_LAZY
|
RTLD_LOCAL
;
#else
...
...
@@ -49,10 +45,31 @@ void* GetTensorRtHandle() {
"library is not found. Ignore this if TensorRT is not needed."
;
std
::
cerr
<<
error_msg
;
}
return
dso_handle
;
}
void
*
GetTensorRtHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
std
::
string
dso_name
=
"libnvinfer.dylib"
;
#elif defined(_WIN32)
std
::
string
dso_name
=
"nvinfer.dll"
;
#else
std
::
string
dso_name
=
"libnvinfer.so"
;
#endif
return
GetDsoHandle
(
dso_name
);
}
void
*
GetTensorRtPluginHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
std
::
string
dso_name
=
"libnvinfer_plugin.dylib"
;
#elif defined(_WIN32)
std
::
string
dso_name
=
"nvinfer_plugin.dll"
;
#else
std
::
string
dso_name
=
"libnvinfer_plugin.so"
;
#endif
return
GetDsoHandle
(
dso_name
);
}
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/tensorrt.h
浏览文件 @
2318fb0e
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <NvInferPlugin.h>
#if !defined(_WIN32)
#include <dlfcn.h>
#endif
...
...
@@ -32,6 +33,10 @@ void* GetTensorRtHandle();
extern
std
::
once_flag
tensorrt_dso_flag
;
extern
void
*
tensorrt_dso_handle
;
void
*
GetTensorRtPluginHandle
();
extern
std
::
once_flag
tensorrt_plugin_dso_flag
;
extern
void
*
tensorrt_plugin_dso_handle
;
#define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
...
...
@@ -50,7 +55,26 @@ extern void* tensorrt_dso_handle;
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
std::call_once(tensorrt_plugin_dso_flag, []() { \
tensorrt_plugin_dso_handle = \
paddle::platform::dynload::GetTensorRtPluginHandle(); \
}); \
static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \
PADDLE_ENFORCE_NOT_NULL(p_##__name, \
platform::errors::Unavailable( \
"Load tensorrt plugin %s failed", #__name)); \
using tensorrt_plugin_func = decltype(&::__name); \
return reinterpret_cast<tensorrt_plugin_func>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#ifdef NV_TENSORRT_MAJOR
#if (NV_TENSORRT_MAJOR >= 6)
#define TENSORRT_RAND_ROUTINE_EACH(__macro) \
__macro(createInferBuilder_INTERNAL); \
...
...
@@ -62,8 +86,13 @@ extern void* tensorrt_dso_handle;
__macro(createInferRuntime_INTERNAL);
#endif
#define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \
__macro(initLibNvInferPlugins);
TENSORRT_RAND_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP
)
#endif
TENSORRT_PLUGIN_RAND_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP
)
#endif // end of NV_TENSORRT_MAJOR
}
// namespace dynload
}
// namespace platform
...
...
paddle/fluid/pybind/inference_api.cc
浏览文件 @
2318fb0e
...
...
@@ -481,6 +481,8 @@ void BindAnalysisConfig(py::module *m) {
py
::
arg
(
"optim_input_shape"
)
=
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
({}),
py
::
arg
(
"disable_trt_plugin_fp16"
)
=
false
)
.
def
(
"enable_tensorrt_oss"
,
&
AnalysisConfig
::
EnableTensorRtOSS
)
.
def
(
"tensorrt_oss_enabled"
,
&
AnalysisConfig
::
tensorrt_oss_enabled
)
.
def
(
"tensorrt_engine_enabled"
,
&
AnalysisConfig
::
tensorrt_engine_enabled
)
.
def
(
"enable_lite_engine"
,
&
AnalysisConfig
::
EnableLiteEngine
,
py
::
arg
(
"precision_mode"
)
=
AnalysisConfig
::
Precision
::
kFloat32
,
...
...
python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
浏览文件 @
2318fb0e
...
...
@@ -20,6 +20,7 @@ import random
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PaddleTensor
...
...
@@ -34,6 +35,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
class
InferencePassTest
(
unittest
.
TestCase
):
def
__init__
(
self
,
methodName
=
'runTest'
):
paddle
.
enable_static
()
super
(
InferencePassTest
,
self
).
__init__
(
methodName
)
self
.
main_program
=
fluid
.
Program
()
self
.
startup_program
=
fluid
.
Program
()
...
...
@@ -211,6 +213,7 @@ class InferencePassTest(unittest.TestCase):
if
flatten
:
out
=
out
.
flatten
()
analysis_output
=
analysis_output
.
flatten
()
self
.
assertTrue
(
np
.
allclose
(
out
,
analysis_output
,
atol
=
atol
),
...
...
@@ -232,6 +235,7 @@ class InferencePassTest(unittest.TestCase):
if
flatten
:
out
=
out
.
flatten
()
tensorrt_output
=
tensorrt_output
.
flatten
()
self
.
assertTrue
(
np
.
allclose
(
out
,
tensorrt_output
,
atol
=
atol
),
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py
0 → 100644
浏览文件 @
2318fb0e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
PassVersionChecker
from
paddle.fluid.core
import
AnalysisConfig
class
TensorRTMatMulDims2Test
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
24
,
24
],
dtype
=
"float32"
)
matmul_out
=
fluid
.
layers
.
matmul
(
x
=
data
,
y
=
data
,
transpose_x
=
self
.
transpose_x
,
transpose_y
=
self
.
transpose_y
,
alpha
=
self
.
alpha
)
out
=
fluid
.
layers
.
batch_norm
(
matmul_out
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
ones
([
24
,
24
]).
astype
(
"float32"
),
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TensorRTMatMulDims2Test
.
TensorRTParam
(
1
<<
30
,
32
,
0
,
AnalysisConfig
.
Precision
.
Float32
,
False
,
False
)
self
.
fetch_list
=
[
out
]
def
set_params
(
self
):
self
.
transpose_x
=
True
self
.
transpose_y
=
True
self
.
alpha
=
2.0
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
class
TensorRTMatMulTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
data
=
fluid
.
data
(
name
=
"data"
,
shape
=
[
-
1
,
6
,
24
,
24
],
dtype
=
"float32"
)
matmul_out
=
fluid
.
layers
.
matmul
(
x
=
data
,
y
=
data
,
transpose_x
=
self
.
transpose_x
,
transpose_y
=
self
.
transpose_y
,
alpha
=
self
.
alpha
)
out
=
fluid
.
layers
.
batch_norm
(
matmul_out
,
is_test
=
True
)
self
.
feeds
=
{
"data"
:
np
.
ones
([
1
,
6
,
24
,
24
]).
astype
(
"float32"
),
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TensorRTMatMulTest
.
TensorRTParam
(
1
<<
30
,
32
,
0
,
AnalysisConfig
.
Precision
.
Float32
,
False
,
False
)
self
.
fetch_list
=
[
out
]
def
set_params
(
self
):
self
.
transpose_x
=
False
self
.
transpose_y
=
False
self
.
alpha
=
1.0
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
class
TensorRTMatMulTransposeXTest
(
TensorRTMatMulTest
):
def
set_params
(
self
):
self
.
transpose_x
=
True
self
.
transpose_y
=
False
self
.
alpha
=
1.0
class
TensorRTMatMulTransposeYTest
(
TensorRTMatMulTest
):
def
set_params
(
self
):
self
.
transpose_x
=
False
self
.
transpose_y
=
True
self
.
alpha
=
1.0
class
TensorRTMatMulScaleTest
(
TensorRTMatMulTest
):
def
set_params
(
self
):
self
.
transpose_x
=
False
self
.
transpose_y
=
False
self
.
alpha
=
2.0
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录