Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b50dbe0b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b50dbe0b
编写于
12月 19, 2022
作者:
W
Wangzheee
提交者:
GitHub
12月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Paddle Inference] General optimization for no_varlen skiplayernorm (#49039)
* General optimization for no_varlen embedding layernorm
上级
9df0ab32
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
369 addition
and
991 deletion
+369
-991
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
+15
-2
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+21
-18
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+115
-160
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+0
-1
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu
...uid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu
+0
-219
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
...luid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
+0
-353
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
...fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
+2
-2
python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
...fluid/tests/unittests/ir/inference/inference_pass_test.py
+1
-1
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_skip_layernorm.py
...unittests/ir/inference/test_trt_convert_skip_layernorm.py
+0
-235
python/paddle/fluid/tests/unittests/ir/inference/test_trt_skip_layernorm_fuse_pass.py
...ittests/ir/inference/test_trt_skip_layernorm_fuse_pass.py
+215
-0
未找到文件。
paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc
浏览文件 @
b50dbe0b
...
...
@@ -169,8 +169,21 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
// attrs
new_desc
.
SetAttr
(
"epsilon"
,
layer_norm
->
Op
()
->
GetAttr
(
"epsilon"
));
new_desc
.
SetAttr
(
"begin_norm_axis"
,
layer_norm
->
Op
()
->
GetAttr
(
"begin_norm_axis"
));
if
(
new_desc
.
HasAttr
(
"begin_norm_axis"
))
{
int32_t
begin_norm_axis
=
PADDLE_GET_CONST
(
int32_t
,
layer_norm
->
Op
()
->
GetAttr
(
"begin_norm_axis"
));
int32_t
input_rank
=
static_cast
<
int32_t
>
(
elementwise_out
->
Var
()
->
GetShape
().
size
());
if
((
begin_norm_axis
!=
-
1
)
&&
(
begin_norm_axis
!=
input_rank
-
1
))
{
LOG
(
WARNING
)
<<
"skip_layernorm pass only support "
"layer_norm'begin_norm_axis == input_rank - 1."
;
return
;
}
new_desc
.
SetAttr
(
"begin_norm_axis"
,
begin_norm_axis
);
}
int32_t
hidden_size
=
layer_norm_scale
->
Var
()
->
GetShape
()[
0
];
new_desc
.
SetAttr
(
"hidden_size"
,
hidden_size
);
auto
fused_node
=
graph
->
CreateOpNode
(
&
new_desc
);
// OpDesc will be copied.
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
b50dbe0b
...
...
@@ -104,24 +104,27 @@ const std::vector<std::string> kTRTSubgraphPasses({
"multihead_matmul_roformer_fuse_pass"
,
//
"constant_folding_pass"
,
//
"vit_attention_fuse_pass"
,
//
"trt_skip_layernorm_fuse_pass"
,
//
"preln_skip_layernorm_fuse_pass"
,
//
"layernorm_shift_partition_fuse_pass"
,
//
"merge_layernorm_fuse_pass"
,
//
"preln_residual_bias_fuse_pass"
,
//
"preln_layernorm_x_fuse_pass"
,
//
"reverse_roll_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"unsqueeze2_eltwise_fuse_pass"
,
//
"trt_squeeze2_matmul_fuse_pass"
,
//
"trt_flatten2_matmul_fuse_pass"
,
//
"trt_map_matmul_v2_to_mul_pass"
,
//
"trt_map_matmul_v2_to_matmul_pass"
,
//
"trt_map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"remove_padding_recover_padding_pass"
,
//
"delete_remove_padding_recover_padding_pass"
,
//
#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading.
#else
"trt_skip_layernorm_fuse_pass"
,
//
"preln_skip_layernorm_fuse_pass"
,
//
#endif
"layernorm_shift_partition_fuse_pass"
,
//
"merge_layernorm_fuse_pass"
,
//
"preln_residual_bias_fuse_pass"
,
//
"preln_layernorm_x_fuse_pass"
,
//
"reverse_roll_fuse_pass"
,
//
"conv_bn_fuse_pass"
,
//
"unsqueeze2_eltwise_fuse_pass"
,
//
"trt_squeeze2_matmul_fuse_pass"
,
//
"trt_flatten2_matmul_fuse_pass"
,
//
"trt_map_matmul_v2_to_mul_pass"
,
//
"trt_map_matmul_v2_to_matmul_pass"
,
//
"trt_map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"remove_padding_recover_padding_pass"
,
//
"delete_remove_padding_recover_padding_pass"
,
//
// "yolo_box_fuse_pass", //
"dense_fc_to_sparse_pass"
,
//
"dense_multihead_matmul_to_sparse_pass"
,
//
...
...
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
浏览文件 @
b50dbe0b
/* Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/utils.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -26,9 +25,20 @@ class SkipLayerNormOpConverter : public OpConverter {
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
#if IS_TRT_VERSION_GE(6000)
VLOG
(
4
)
<<
"convert fused skip layernorm op to tensorrt layer"
;
PADDLE_ENFORCE_EQ
(
engine_
->
with_dynamic_shape
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Skip_layernorm must run the dynamic shape mode."
));
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
auto
GetWeight
=
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_tensor
=
temp_var
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight
=
engine_
->
GetTrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
};
// Declare inputs
auto
*
input1
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"X"
)[
0
]);
auto
*
input2
=
engine_
->
GetITensor
(
op_desc
.
Input
(
"Y"
)[
0
]);
...
...
@@ -36,173 +46,118 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs
.
push_back
(
input1
);
inputs
.
push_back
(
input2
);
bool
enable_int8
=
op_desc
.
HasAttr
(
"enable_int8"
);
bool
enable_int8
=
false
;
if
(
op_desc
.
HasAttr
(
"enable_int8"
))
{
enable_int8
=
PADDLE_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"enable_int8"
));
}
auto
bias_weight
=
GetWeight
(
"Bias"
).
get
();
auto
scale_weight
=
GetWeight
(
"Scale"
).
get
();
nvinfer1
::
ILayer
*
layer
=
nullptr
;
bool
flag_varseqlen
=
engine_
->
use_varseqlen
()
&&
engine_
->
tensorrt_transformer_posid
()
!=
""
&&
engine_
->
tensorrt_transformer_maskid
()
!=
""
;
if
(
flag_varseqlen
)
{
auto
GetWeight
=
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_tensor
=
temp_var
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight
=
engine_
->
GetTrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
};
auto
bias_weight
=
GetWeight
(
"Bias"
).
get
();
auto
scale_weight
=
GetWeight
(
"Scale"
).
get
();
if
(
engine_
->
with_interleaved
())
{
VLOG
(
4
)
<<
"fused skip_layernorm op: use_varseqlen and with_interleaved"
;
if
(
!
enable_int8
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"use with_interleaved must be int8."
));
}
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"3"
);
PADDLE_ENFORCE_NE
(
creator
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"beta"
,
bias_weight
.
values
,
GetPluginFieldType
(
bias_weight
.
type
),
static_cast
<
int32_t
>
(
bias_weight
.
count
)},
{
"gamma"
,
scale_weight
.
values
,
GetPluginFieldType
(
scale_weight
.
type
),
static_cast
<
int32_t
>
(
scale_weight
.
count
)
}};
nvinfer1
::
PluginFieldCollection
*
pluginPtr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
pluginPtr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
pluginPtr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
auto
pluginObj
=
creator
->
createPlugin
(
"CustomSkipLayerNormPluginDynamic"
,
pluginPtr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
PADDLE_ENFORCE_NE
(
plugin_layer
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
layer
=
plugin_layer
;
}
else
{
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"2"
);
PADDLE_ENFORCE_NE
(
creator
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
int
ld
=
input1
->
getDimensions
().
d
[
2
];
// hidden dimension
PADDLE_ENFORCE_GT
(
ld
,
0
,
platform
::
errors
::
InvalidArgument
(
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"
));
if
(
enable_int8
)
{
type
=
static_cast
<
int
>
(
nvinfer1
::
DataType
::
kHALF
);
}
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"ld"
,
&
ld
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"beta"
,
bias_weight
.
values
,
GetPluginFieldType
(
bias_weight
.
type
),
static_cast
<
int32_t
>
(
bias_weight
.
count
)},
{
"gamma"
,
scale_weight
.
values
,
GetPluginFieldType
(
scale_weight
.
type
),
static_cast
<
int32_t
>
(
scale_weight
.
count
)},
};
nvinfer1
::
PluginFieldCollection
*
pluginPtr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
*
pluginPtr
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
pluginPtr
->
nbFields
=
static_cast
<
int
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
auto
pluginObj
=
creator
->
createPlugin
(
"CustomSkipLayerNormPluginDynamic"
,
pluginPtr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
PADDLE_ENFORCE_NE
(
plugin_layer
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
layer
=
plugin_layer
;
if
(
flag_varseqlen
&&
engine_
->
with_interleaved
())
{
VLOG
(
4
)
<<
"fused skip_layernorm op: use_varseqlen and with_interleaved"
;
if
(
!
enable_int8
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"use with_interleaved must be int8."
));
}
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"3"
);
PADDLE_ENFORCE_NE
(
creator
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"beta"
,
bias_weight
.
values
,
GetPluginFieldType
(
bias_weight
.
type
),
static_cast
<
int32_t
>
(
bias_weight
.
count
)},
{
"gamma"
,
scale_weight
.
values
,
GetPluginFieldType
(
scale_weight
.
type
),
static_cast
<
int32_t
>
(
scale_weight
.
count
)}};
nvinfer1
::
PluginFieldCollection
*
pluginPtr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
nvinfer1
::
PluginFieldCollection
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
pluginPtr
->
nbFields
=
static_cast
<
int32_t
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
auto
pluginObj
=
creator
->
createPlugin
(
"CustomSkipLayerNormPluginDynamic"
,
pluginPtr
);
free
(
pluginPtr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
PADDLE_ENFORCE_NE
(
plugin_layer
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
layer
=
plugin_layer
;
}
else
{
auto
GetFp16Weight
=
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_tensor
=
temp_var
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight
=
engine_
->
GetFp16TrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
};
auto
GetFp32Weight
=
[
&
](
const
std
::
string
&
arg_name
)
->
TensorRTEngine
::
Weight
{
std
::
string
var_name
=
op_desc
.
Input
(
arg_name
).
front
();
auto
*
temp_var
=
scope
.
FindVar
(
var_name
);
auto
*
temp_tensor
=
temp_var
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight
=
engine_
->
GetFp32TrtWeight
(
var_name
,
*
temp_tensor
);
return
weight
;
};
// bool with_fp16 = engine_->WithFp16() &&
// !engine_->disable_trt_plugin_fp16() &&
// (input1->getType() == nvinfer1::DataType::kHALF);
bool
with_fp16
=
false
;
TensorRTEngine
::
Weight
bias_weight
,
scale_weight
;
if
(
with_fp16
)
{
bias_weight
=
GetFp16Weight
(
"Bias"
);
scale_weight
=
GetFp16Weight
(
"Scale"
);
}
else
{
bias_weight
=
GetFp32Weight
(
"Bias"
);
scale_weight
=
GetFp32Weight
(
"Scale"
);
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"2"
);
PADDLE_ENFORCE_NE
(
creator
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
int32_t
type
=
static_cast
<
int32_t
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
if
(
enable_int8
)
{
type
=
static_cast
<
int32_t
>
(
nvinfer1
::
DataType
::
kHALF
);
}
float
eps
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
new
plugin
::
SkipLayerNormPluginDynamic
(
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
bias_weight
.
get
().
values
)),
const_cast
<
void
*>
(
static_cast
<
const
void
*>
(
scale_weight
.
get
().
values
)),
bias_weight
.
get
().
count
,
scale_weight
.
get
().
count
,
eps
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
2
,
plugin
);
int32_t
hidden_size
=
PADDLE_GET_CONST
(
int32_t
,
op_desc
.
GetAttr
(
"hidden_size"
));
PADDLE_ENFORCE_GT
(
hidden_size
,
0
,
platform
::
errors
::
InvalidArgument
(
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"
));
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
{
"type_id"
,
&
type
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"ld"
,
&
hidden_size
,
nvinfer1
::
PluginFieldType
::
kINT32
,
1
},
{
"beta"
,
bias_weight
.
values
,
GetPluginFieldType
(
bias_weight
.
type
),
static_cast
<
int32_t
>
(
bias_weight
.
count
)},
{
"gamma"
,
scale_weight
.
values
,
GetPluginFieldType
(
scale_weight
.
type
),
static_cast
<
int32_t
>
(
scale_weight
.
count
)},
};
nvinfer1
::
PluginFieldCollection
*
pluginPtr
=
static_cast
<
nvinfer1
::
PluginFieldCollection
*>
(
malloc
(
sizeof
(
nvinfer1
::
PluginFieldCollection
)
+
fields
.
size
()
*
sizeof
(
nvinfer1
::
PluginField
)));
// remember to free
pluginPtr
->
nbFields
=
static_cast
<
int32_t
>
(
fields
.
size
());
pluginPtr
->
fields
=
fields
.
data
();
auto
pluginObj
=
creator
->
createPlugin
(
"CustomSkipLayerNormPluginDynamic"
,
pluginPtr
);
free
(
pluginPtr
);
auto
plugin_layer
=
engine_
->
network
()
->
addPluginV2
(
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
PADDLE_ENFORCE_NE
(
plugin_layer
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
layer
=
plugin_layer
;
}
auto
output_name
=
op_desc
.
Output
(
"Out"
)[
0
];
RreplenishLayerAndOutput
(
layer
,
"skip_layernorm"
,
{
output_name
},
test_mode
);
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"
));
#endif
}
};
...
...
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
浏览文件 @
b50dbe0b
...
...
@@ -12,7 +12,6 @@ list(
layer_norm_op_plugin.cu
instance_norm_op_plugin.cu
qkv_to_context_plugin.cu
skip_layernorm_op_plugin.cu
hard_swish_op_plugin.cu
stack_op_plugin.cu
anchor_generator_op_plugin.cu
...
...
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu
已删除
100644 → 0
浏览文件 @
9df0ab32
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
// Copies the device-side scale/bias pointers from another impl instance into
// this one so both refer to the same GPU buffers.
//
// Nothing is copied when `anthor` is not a SkipLayerNormPluginDynamicImpl<T>
// (for example a different precision instantiation) or when it has not been
// initialized yet.
//
// NOTE(review): the pointers are shared without any ownership tracking, so
// terminate() on either instance frees memory the other one still refers to.
template <typename T>
void SkipLayerNormPluginDynamicImpl<T>::shareGPUData(
    const SkipLayerNormPluginDynamicImplBase *anthor) {
  auto *ptr = dynamic_cast<const SkipLayerNormPluginDynamicImpl<T> *>(anthor);
  // dynamic_cast yields nullptr on a type mismatch (or a null argument);
  // dereferencing it unconditionally would crash.
  if (ptr == nullptr || !ptr->is_initialized_) {
    return;
  }
  scale_gpu_ = ptr->scale_gpu_;
  bias_gpu_ = ptr->bias_gpu_;
}
// Uploads the host-side bias and scale arrays to freshly allocated device
// buffers. Does nothing if the instance is already initialized.
//
// Returns 0 on success and -1 if a CUDA allocation or copy fails. On failure
// every buffer allocated by this call is released again and the instance stays
// uninitialized, so a later call can retry.
template <typename T>
int SkipLayerNormPluginDynamicImpl<T>::initialize() {
  if (is_initialized_) {
    return 0;
  }

  // Allocates `count` elements on the device and copies them from `host`.
  // Leaves `*device` null when anything goes wrong.
  auto upload = [](T **device, const T *host, int count) -> bool {
    const size_t bytes = sizeof(T) * count;
    if (cudaMalloc(device, bytes) != cudaSuccess) {
      *device = nullptr;
      return false;
    }
    if (cudaMemcpy(*device, host, bytes, cudaMemcpyHostToDevice) !=
        cudaSuccess) {
      cudaFree(*device);
      *device = nullptr;
      return false;
    }
    return true;
  };

  const bool bias_uploaded = (bias_ != nullptr);
  if (bias_uploaded && !upload(&bias_gpu_, bias_, bias_size_)) {
    return -1;
  }
  if (scale_ && !upload(&scale_gpu_, scale_, scale_size_)) {
    // Only release the bias buffer if this call created it.
    if (bias_uploaded) {
      cudaFree(bias_gpu_);
      bias_gpu_ = nullptr;
    }
    return -1;
  }

  is_initialized_ = true;
  return 0;
}
// Releases the device-side bias and scale buffers (if any) and marks the
// instance as uninitialized, so that a subsequent initialize() uploads the
// parameters again instead of returning early with null device pointers.
template <typename T>
void SkipLayerNormPluginDynamicImpl<T>::terminate() {
  if (bias_gpu_) {
    cudaFree(bias_gpu_);
    bias_gpu_ = nullptr;
  }
  if (scale_gpu_) {
    cudaFree(scale_gpu_);
    scale_gpu_ = nullptr;
  }
  is_initialized_ = false;
}
// Initializes the precision-specific implementation and forwards its status
// code (0 on success) to the caller instead of unconditionally reporting
// success.
int SkipLayerNormPluginDynamic::initialize() TRT_NOEXCEPT {
  return impl_->initialize();
}
// Delegates resource release to the precision-specific implementation.
void SkipLayerNormPluginDynamic::terminate() TRT_NOEXCEPT {
  SkipLayerNormPluginDynamicImplBase *const impl = impl_;
  impl->terminate();
}
// The output has exactly the dimensions of the first input; `output_index`,
// `nb_inputs` and `expr_builder` are not consulted.
nvinfer1::DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputs,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  const nvinfer1::DimsExprs &first_input_dims = *inputs;
  return first_input_dims;
}
// Reports whether the tensor at position `pos` of `in_out` (inputs first,
// then outputs) has a type/format this plugin accepts.
//
//  * pos == 0: the first input must be linear and either half (when fp16 was
//    requested and TRT_PLUGIN_FP16_AVALIABLE is defined) or float.
//  * pos > 0:  the tensor must match the type and format of the tensor
//    directly before it.
//
// Enforces that `in_out` is non-null, that there is exactly one output and
// that `pos` is within [0, nb_inputs + nb_outputs).
bool SkipLayerNormPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of SkipLayerNorm plugin should not be nullptr."));
  PADDLE_ENFORCE_EQ(nb_outputs,
                    1,
                    platform::errors::InvalidArgument(
                        "The SkipLayerNorm's output should be one, "
                        "but it's (%d) outputs.",
                        nb_outputs));
  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
  if (pos == 0) {
    // Half is only acceptable when it was requested and the build supports
    // it; every other configuration falls back to float.
    nvinfer1::DataType expected_type = nvinfer1::DataType::kFLOAT;
#ifdef TRT_PLUGIN_FP16_AVALIABLE
    if (with_fp16_) {
      expected_type = nvinfer1::DataType::kHALF;
    }
#endif
    return (desc.type == expected_type) &&
           (desc.format == nvinfer1::TensorFormat::kLINEAR);
  }

  // Second input and output: must agree with the preceding tensor.
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  return desc.type == prev.type && desc.format == prev.format;
}
// Returns the data type of the plugin's single output, which is the type of
// the first input. Enforces that `index` is 0 and that the first input is
// float or half.
nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The SkipLayerNorm Plugin only has one output, so the "
                        "index value should be 0, but get %d.",
                        index));

  const nvinfer1::DataType first_type = input_types[0];
  const bool is_float_or_half = first_type == nvinfer1::DataType::kFLOAT ||
                                first_type == nvinfer1::DataType::kHALF;
  PADDLE_ENFORCE_EQ(is_float_or_half,
                    true,
                    platform::errors::InvalidArgument(
                        "The input type should be half or float"));
  return first_type;
}
// Runs the fused skip + layer-norm functor on `stream`.
//
// inputs[0] and inputs[1] are the two tensors handed to the functor,
// outputs[0] receives the result; all three are reinterpreted as arrays of T.
// The total element count comes from ProductDim over the first input's dims
// and the hidden size is read from dimension index 2.
// NOTE(review): this presumes the hidden dimension sits at index 2 of the
// first input - confirm against the converter that feeds this plugin.
//
// `output_desc` and `workspace` are not used.
//
// Returns non-zero if cudaGetLastError() reports an error after the launch.
template <typename T>
int SkipLayerNormPluginDynamicImpl<T>::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
    const void *const *inputs,
    void *const *outputs,
    void *workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;
  size_t num = ProductDim(input_dims);
  int hidden = input_dims.d[2];

  // The runtime tensor type has to match the precision this impl was
  // instantiated for.
  auto input_type = input_desc[0].type;
  if (std::is_same<T, float>::value) {
    PADDLE_ENFORCE_EQ(
        input_type == nvinfer1::DataType::kFLOAT,
        true,
        platform::errors::InvalidArgument(
            "The SkipLayernorm Plugin only support fp32 input."));
  } else if (std::is_same<T, half>::value) {
    PADDLE_ENFORCE_EQ(
        input_type == nvinfer1::DataType::kHALF,
        true,
        platform::errors::InvalidArgument(
            "The SkipLayernorm Plugin only support fp16 input."));
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "Unsupport data type, the out type of SkipLayernorm should be "
        "float or half."));
  }

  const T *input1 = reinterpret_cast<const T *>(inputs[0]);
  const T *input2 = reinterpret_cast<const T *>(inputs[1]);
  auto *output = reinterpret_cast<T *>(outputs[0]);

  operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;
  skip_layer_norm_func(num,
                       hidden,
                       input1,
                       input2,
                       scale_gpu_,
                       bias_gpu_,
                       output,
                       eps_,
                       stream);
  return cudaGetLastError() != cudaSuccess;
}
// Forwards the launch to the precision-specific implementation and returns
// its status (0 on success).
//
// The implementation already consumes the pending CUDA error with
// cudaGetLastError(), which resets it to cudaSuccess; querying it a second
// time here would therefore always report success and hide a failed launch.
int SkipLayerNormPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
    const void *const *inputs,
    void *const *outputs,
    void *workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  return impl_->enqueue(
      input_desc, output_desc, inputs, outputs, workspace, stream);
}
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h
已删除
100644 → 0
浏览文件 @
9df0ab32
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/phi/common/data_type.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
#if IS_TRT_VERSION_GE(6000)
// Precision-independent interface of the skip-layernorm plugin
// implementation. SkipLayerNormPluginDynamic holds a pointer to this type and
// forwards its lifecycle and enqueue calls through it.
class SkipLayerNormPluginDynamicImplBase {
 public:
  SkipLayerNormPluginDynamicImplBase() = default;
  virtual ~SkipLayerNormPluginDynamicImplBase() = default;

  // Acquires the resources the implementation needs; returns a status code.
  virtual int initialize() = 0;

  // Releases whatever initialize() acquired.
  virtual void terminate() = 0;

  // Launches the computation on `stream`; returns a status code.
  virtual int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                      const nvinfer1::PluginTensorDesc* outputDesc,
                      const void* const* inputs,
                      void* const* outputs,
                      void* workspace,
                      cudaStream_t stream) = 0;

  // Lets this instance reuse GPU data held by `anthor`.
  virtual void shareGPUData(
      const SkipLayerNormPluginDynamicImplBase* anthor) = 0;
};
// Implementation of the skip-layernorm plugin for element type T.
//
// Stores non-owning pointers to the host-side bias/scale arrays together with
// their element counts and the epsilon value. The device-side copies
// (`bias_gpu_`, `scale_gpu_`) start out null; the member functions declared
// here are defined out of line.
template <typename T>
class SkipLayerNormPluginDynamicImpl
    : public SkipLayerNormPluginDynamicImplBase {
 public:
  // bias / scale:           host arrays of bias_size / scale_size elements.
  // eps:                    epsilon value stored for the layer-norm functor.
  explicit SkipLayerNormPluginDynamicImpl(
      T* bias, T* scale, int bias_size, int scale_size, const float eps)
      : bias_(bias),
        scale_(scale),
        bias_size_(bias_size),
        scale_size_(scale_size),
        eps_(eps) {}

  ~SkipLayerNormPluginDynamicImpl() {}

  // `override` makes the compiler verify that each of these really overrides
  // the corresponding virtual in SkipLayerNormPluginDynamicImplBase.
  int initialize() override;
  void terminate() override;
  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;
  void shareGPUData(const SkipLayerNormPluginDynamicImplBase* anthor) override;

 private:
  // Host-side parameters (not owned by this class).
  T* bias_{nullptr};
  T* scale_{nullptr};

  // data on devices
  T* bias_gpu_{nullptr};
  T* scale_gpu_{nullptr};

  int bias_size_;
  int scale_size_;
  float eps_;

  bool is_initialized_{false};
};
class
SkipLayerNormPluginDynamic
:
public
DynamicPluginTensorRT
{
public:
explicit
SkipLayerNormPluginDynamic
(
void
*
bias
,
void
*
scale
,
int
bias_size
,
int
scale_size
,
float
eps
,
bool
with_fp16
)
:
bias_
(
bias
),
scale_
(
scale
),
bias_size_
(
bias_size
),
scale_size_
(
scale_size
),
eps_
(
eps
),
own_host_buff_
(
false
)
{
with_fp16_
=
with_fp16
;
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG
(
1
)
<<
"TRT Plugin DataType selected. SkipLayerNorm-->fp16"
;
instantiateImpl
<
half
>
();
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.EnableTensorRtEngine(1 << 30, 1, 5, "
"AnalysisConfig::Precision::kFloat32, false, false) "
));
#endif
}
else
{
VLOG
(
1
)
<<
"TRT Plugin DataType selected. SkipLayerNorm-->fp32"
;
instantiateImpl
<
float
>
();
}
}
SkipLayerNormPluginDynamic
(
void
const
*
serial_data
,
size_t
serial_length
)
:
own_host_buff_
(
true
)
{
// the first var is with_fp16, we will use it.
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
with_fp16_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
bias_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
scale_size_
);
DeserializeValue
(
&
serial_data
,
&
serial_length
,
&
eps_
);
if
(
with_fp16_
)
{
if
(
bias_size_
)
{
bias_
=
new
half
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
half
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
half
);
serial_length
-=
bias_size_
*
sizeof
(
half
);
if
(
scale_size_
)
{
scale_
=
new
half
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
half
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
half
);
serial_length
-=
scale_size_
*
sizeof
(
half
);
}
else
{
if
(
bias_size_
)
{
bias_
=
new
float
[
bias_size_
];
memcpy
(
bias_
,
serial_data
,
sizeof
(
float
)
*
bias_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
bias_size_
*
sizeof
(
float
);
serial_length
-=
bias_size_
*
sizeof
(
float
);
if
(
scale_size_
)
{
scale_
=
new
float
[
scale_size_
];
memcpy
(
scale_
,
serial_data
,
sizeof
(
float
)
*
scale_size_
);
}
reinterpret_cast
<
char
const
*&>
(
serial_data
)
+=
scale_size_
*
sizeof
(
float
);
serial_length
-=
scale_size_
*
sizeof
(
float
);
}
if
(
with_fp16_
)
{
#ifdef TRT_PLUGIN_FP16_AVALIABLE
instantiateImpl
<
half
>
();
#else
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"The Ernie(Bert) tensorRT plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.EnableTensorRtEngine(1 << 30, 1, 5, "
"AnalysisConfig::Precision::kFloat32, false, false) "
));
#endif
}
else
{
instantiateImpl
<
float
>
();
}
}
nvinfer1
::
IPluginV2DynamicExt
*
clone
()
const
TRT_NOEXCEPT
override
{
auto
ptr
=
new
SkipLayerNormPluginDynamic
(
bias_
,
scale_
,
bias_size_
,
scale_size_
,
eps_
,
with_fp16_
);
ptr
->
shareGPUData
(
this
);
return
ptr
;
}
const
char
*
getPluginType
()
const
TRT_NOEXCEPT
override
{
return
"skip_layernorm_plugin"
;
}
int
getNbOutputs
()
const
TRT_NOEXCEPT
override
{
return
1
;
}
int
initialize
()
TRT_NOEXCEPT
override
;
void
terminate
()
TRT_NOEXCEPT
override
;
size_t
getSerializationSize
()
const
TRT_NOEXCEPT
override
{
size_t
sum_num
=
0
;
sum_num
+=
SerializedSize
(
with_fp16_
);
if
(
with_fp16_
)
{
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
half
);
}
else
{
sum_num
+=
(
bias_size_
+
scale_size_
)
*
sizeof
(
float
);
}
sum_num
+=
SerializedSize
(
bias_size_
);
sum_num
+=
SerializedSize
(
scale_size_
);
sum_num
+=
SerializedSize
(
eps_
);
return
sum_num
;
}
void
serialize
(
void
*
buffer
)
const
TRT_NOEXCEPT
override
{
// the first var is for with_fp16, we will use it later;
SerializeValue
(
&
buffer
,
with_fp16_
);
SerializeValue
(
&
buffer
,
bias_size_
);
SerializeValue
(
&
buffer
,
scale_size_
);
SerializeValue
(
&
buffer
,
eps_
);
if
(
with_fp16_
)
{
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
bias_
)[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
half
*>
(
scale_
)[
i
]);
}
}
else
{
for
(
int
i
=
0
;
i
<
bias_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
bias_
)[
i
]);
}
for
(
int
i
=
0
;
i
<
scale_size_
;
++
i
)
{
SerializeValue
(
&
buffer
,
reinterpret_cast
<
float
*>
(
scale_
)[
i
]);
}
}
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputs
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
override
;
bool
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
TRT_NOEXCEPT
override
;
void
configurePlugin
(
const
nvinfer1
::
DynamicPluginTensorDesc
*
in
,
int
nb_inputs
,
const
nvinfer1
::
DynamicPluginTensorDesc
*
out
,
int
nb_outputs
)
TRT_NOEXCEPT
override
{}
size_t
getWorkspaceSize
(
const
nvinfer1
::
PluginTensorDesc
*
inputs
,
int
nb_inputs
,
const
nvinfer1
::
PluginTensorDesc
*
outputs
,
int
nb_outputs
)
const
TRT_NOEXCEPT
override
{
return
0
;
}
int
enqueue
(
const
nvinfer1
::
PluginTensorDesc
*
input_desc
,
const
nvinfer1
::
PluginTensorDesc
*
output_desc
,
const
void
*
const
*
inputs
,
void
*
const
*
outputs
,
void
*
workspace
,
cudaStream_t
stream
)
TRT_NOEXCEPT
override
;
nvinfer1
::
DataType
getOutputDataType
(
int
index
,
const
nvinfer1
::
DataType
*
input_types
,
int
nb_inputs
)
const
TRT_NOEXCEPT
override
;
void
destroy
()
TRT_NOEXCEPT
override
{
if
(
own_host_buff_
)
{
if
(
with_fp16_
)
{
delete
[]
reinterpret_cast
<
half
*>
(
bias_
);
delete
[]
reinterpret_cast
<
half
*>
(
scale_
);
}
else
{
delete
[]
reinterpret_cast
<
float
*>
(
bias_
);
delete
[]
reinterpret_cast
<
float
*>
(
scale_
);
}
}
delete
impl_
;
delete
this
;
}
private:
void
*
bias_
{
nullptr
};
void
*
scale_
{
nullptr
};
int
bias_size_
;
int
scale_size_
;
float
eps_
;
bool
own_host_buff_
{
false
};
SkipLayerNormPluginDynamicImplBase
*
impl_
{
nullptr
};
void
shareGPUData
(
const
SkipLayerNormPluginDynamic
*
anthor
)
{
impl_
->
shareGPUData
(
anthor
->
impl_
);
}
// Creates the implementation object for element type U and stores it in
// impl_. The type-erased bias_/scale_ buffers are reinterpreted as U* and
// passed along with their element counts and eps_.
// NOTE(review): any previous value of impl_ is overwritten without being
// deleted -- presumably this is called once per instance; confirm.
template <typename U>
void instantiateImpl() {
  U* const typed_bias = reinterpret_cast<U*>(bias_);
  U* const typed_scale = reinterpret_cast<U*>(scale_);
  impl_ = new SkipLayerNormPluginDynamicImpl<U>(
      typed_bias, typed_scale, bias_size_, scale_size_, eps_);
}
};
// Plugin creator for SkipLayerNormPluginDynamic.
//
// It identifies the plugin as "skip_layernorm_plugin" version "1" and rebuilds
// plugin instances from serialized data. Creating a plugin from a field
// collection is not supported: createPlugin() always returns nullptr.
class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
 public:
  SkipLayerNormPluginDynamicCreator() {}

  // Name under which the plugin is looked up.
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "skip_layernorm_plugin";
  }

  // Version string paired with the plugin name.
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  // Returns the (empty) field collection. field_collection_ is explicitly
  // zero-initialized below so that callers never read an indeterminate
  // nbFields / fields pair through this pointer.
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }

  // Not supported: both arguments are ignored and nullptr is returned.
  nvinfer1::IPluginV2* createPlugin(const char* name,
                                    const nvinfer1::PluginFieldCollection* fc)
      TRT_NOEXCEPT override {
    return nullptr;
  }

  // Reconstructs a plugin from `serial_length` bytes at `serial_data`.
  // `name` is ignored. The returned object is heap-allocated; it deletes
  // itself in SkipLayerNormPluginDynamic::destroy().
  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new SkipLayerNormPluginDynamic(serial_data, serial_length);
  }

  // Stores the namespace. A null pointer is treated as the empty namespace,
  // because assigning a null const char* to std::string is undefined
  // behavior.
  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = (lib_namespace != nullptr) ? lib_namespace : "";
  }

  // Returns the namespace last set through setPluginNamespace().
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  // No creation-time fields are advertised: zero fields, null array.
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};
// Hands SkipLayerNormPluginDynamicCreator to the REGISTER_TRT_PLUGIN_V2
// macro (defined elsewhere in the project).
REGISTER_TRT_PLUGIN_V2(SkipLayerNormPluginDynamicCreator);
#endif
}
// namespace plugin
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
浏览文件 @
b50dbe0b
...
...
@@ -145,7 +145,7 @@ void trt_ernie(bool with_fp16,
// Runs trt_ernie with with_fp16 == false and compares against the three
// reference values below.
//
// The diff view listed both the removed call (tolerance 1e-5) and the call
// that replaced it (tolerance 1e-4). Only the replacement is kept; keeping
// both would run the model twice and still enforce the old, stricter
// tolerance.
TEST(AnalysisPredictor, no_fp16) {
  std::vector<float> result = {0.597841, 0.219972, 0.182187};
  trt_ernie(false, result, 1e-4);
}
TEST
(
AnalysisPredictor
,
fp16
)
{
...
...
@@ -158,7 +158,7 @@ TEST(AnalysisPredictor, fp16) {
// Same check as the no_fp16 case, but with a fourth argument of 2 passed to
// trt_ernie and the three reference values repeated twice.
//
// The diff view listed both the removed call (tolerance 1e-5) and the call
// that replaced it (tolerance 1e-4). Only the replacement is kept; keeping
// both would run the model twice and still enforce the old, stricter
// tolerance.
TEST(AnalysisPredictor, no_fp16_bs2) {
  std::vector<float> result = {
      0.597841, 0.219972, 0.182187, 0.597841, 0.219972, 0.182187};
  trt_ernie(false, result, 1e-4, 2);
}
TEST
(
AnalysisPredictor
,
fp16_bs2
)
{
...
...
python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
浏览文件 @
b50dbe0b
...
...
@@ -36,7 +36,7 @@ class InferencePassTest(unittest.TestCase):
self
.
enable_mkldnn
=
False
self
.
enable_mkldnn_bfloat16
=
False
self
.
enable_trt
=
False
self
.
enable_tensorrt_varseqlen
=
Tru
e
self
.
enable_tensorrt_varseqlen
=
Fals
e
self
.
trt_parameters
=
None
self
.
dynamic_shape_params
=
None
self
.
enable_lite
=
False
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_skip_layernorm.py
已删除
100644 → 0
浏览文件 @
9df0ab32
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
functools
import
partial
from
typing
import
Any
,
Dict
,
List
import
numpy
as
np
from
program_config
import
ProgramConfig
,
TensorConfig
from
trt_layer_auto_scan_test
import
TrtLayerAutoScanTest
import
paddle.inference
as
paddle_infer
class TrtConvertSkipLayernormTest(TrtLayerAutoScanTest):
    """Auto-scan test for converting a single ``skip_layernorm`` op to TensorRT.

    Programs are sampled over input rank (2/3/4), batch size, ``epsilon``,
    ``begin_norm_axis`` and ``enable_int8``; each accepted program is run with
    dynamic shape in FP32 and in FP16.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return False for sampled programs that must not be run."""
        inputs = program_config.inputs
        attrs = [op.attrs for op in program_config.ops]
        begin_norm_axis = attrs[0]['begin_norm_axis']

        # A non-negative begin_norm_axis must be smaller than the rank of X.
        if begin_norm_axis >= 0 and (
            len(inputs['skip_layernorm_inputX_data'].shape) <= begin_norm_axis
        ):
            return False

        # 2D input is not supported.
        return self.dims != 2

    def sample_program_configs(self):
        """Yield one ProgramConfig per sampled attribute combination."""

        def ones_for_current_rank(batch):
            # self.dims is read when the data is generated, not when the
            # partial is built, so it reflects the most recently yielded
            # configuration.
            tails = {4: [6, 128, 768], 3: [128, 768], 2: [768]}
            tail = tails.get(self.dims)
            if tail is None:
                return None
            return np.ones([batch] + tail).astype(np.float32)

        def generate_input1(attrs: List[Dict[str, Any]], batch):
            return ones_for_current_rank(batch)

        def generate_input2(attrs: List[Dict[str, Any]], batch):
            return ones_for_current_rank(batch)

        def generate_weight1(attrs: List[Dict[str, Any]]):
            return np.random.random([768]).astype(np.float32)

        def generate_weight2(attrs: List[Dict[str, Any]]):
            return np.random.random([768]).astype(np.float32)

        combinations = (
            (dims, batch, epsilon, begin_norm_axis, enable_int8)
            for dims in [2, 3, 4]
            for batch in [1, 2, 4]
            for epsilon in [1e-5]
            for begin_norm_axis in [0, 1, 2, -1]
            for enable_int8 in [False, True]
        )
        for dims, batch, epsilon, begin_norm_axis, enable_int8 in combinations:
            self.dims = dims
            dics = [
                {
                    "epsilon": epsilon,
                    "begin_norm_axis": begin_norm_axis,
                    "enable_int8": enable_int8,
                },
                {},
            ]
            skip_layernorm_op = {
                "op_type": "skip_layernorm",
                "op_inputs": {
                    "X": ["skip_layernorm_inputX_data"],
                    "Y": ["skip_layernorm_inputY_data"],
                    "Bias": ["Bias"],
                    "Scale": ["Scale"],
                },
                "op_outputs": {"Out": ["skip_layernorm_out"]},
                "op_attrs": dics[0],
            }
            ops = self.generate_op_config([skip_layernorm_op])

            weight_configs = {
                "Bias": TensorConfig(
                    data_gen=partial(generate_weight1, dics)
                ),
                "Scale": TensorConfig(
                    data_gen=partial(generate_weight2, dics)
                ),
            }
            input_configs = {
                "skip_layernorm_inputX_data": TensorConfig(
                    data_gen=partial(generate_input1, dics, batch)
                ),
                "skip_layernorm_inputY_data": TensorConfig(
                    data_gen=partial(generate_input2, dics, batch)
                ),
            }
            program_config = ProgramConfig(
                ops=ops,
                weights=weight_configs,
                inputs=input_configs,
                outputs=["skip_layernorm_out"],
            )

            yield program_config

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield (inference config, expected node counts, tolerance) tuples."""

        def generate_dynamic_shape(attrs):
            # Per rank: (trailing dims for min/opt, trailing dims for max).
            # The last trailing dim is also the length of Bias and Scale.
            profiles = {
                4: ([6, 128, 768], [6, 768, 3072]),
                3: ([128, 768], [768, 3072]),
                2: ([768], [3072]),
            }
            if self.dims not in profiles:
                return
            small_tail, large_tail = profiles[self.dims]

            def shape_map(batch, tail):
                return {
                    "skip_layernorm_inputX_data": [batch] + tail,
                    "skip_layernorm_inputY_data": [batch] + tail,
                    "Bias": [tail[-1]],
                    "Scale": [tail[-1]],
                }

            self.dynamic_shape.min_input_shape = shape_map(1, small_tail)
            self.dynamic_shape.max_input_shape = shape_map(4, large_tail)
            self.dynamic_shape.opt_input_shape = shape_map(2, small_tail)

        def clear_dynamic_shape():
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            return (1, 3) if dynamic_shape else (0, 4)

        attrs = [op.attrs for op in program_config.ops]

        # The static-shape variants are disabled:
        # # for static_shape
        # clear_dynamic_shape()
        # self.trt_param.precision = paddle_infer.PrecisionType.Float32
        # yield self.create_inference_config(), generate_trt_nodes_num(
        #     attrs, False), 1e-5
        # self.trt_param.precision = paddle_infer.PrecisionType.Half
        # yield self.create_inference_config(), generate_trt_nodes_num(
        #     attrs, False), (1e-3, 1e-3)

        # for dynamic_shape
        generate_dynamic_shape(attrs)

        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield (
            self.create_inference_config(),
            generate_trt_nodes_num(attrs, True),
            1e-5,
        )

        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield (
            self.create_inference_config(),
            generate_trt_nodes_num(attrs, True),
            (1e-3, 1e-3),
        )

    def add_skip_trt_case(self):
        """No skip cases are registered for this op."""
        pass

    def test(self):
        """Register skip cases (none) and run the auto-scan test."""
        self.add_skip_trt_case()
        self.run_test()
# Run the test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/ir/inference/test_trt_skip_layernorm_fuse_pass.py
0 → 100644
浏览文件 @
b50dbe0b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
shutil
import
unittest
import
numpy
as
np
from
inference_pass_test
import
InferencePassTest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.core
import
AnalysisConfig
,
PassVersionChecker
class SkipLayernormFusePassTest0(InferencePassTest):
    """add + layer_norm on [-1, 3, 128, 128] float32 inputs, TensorRT Float32."""

    def setUp(self):
        """Build the add -> layer_norm program and the TensorRT settings."""
        with fluid.program_guard(self.main_program, self.startup_program):
            lhs = paddle.static.data(
                name="data1", shape=[-1, 3, 128, 128], dtype="float32"
            )
            rhs = paddle.static.data(
                name="data2", shape=[-1, 3, 128, 128], dtype="float32"
            )
            summed = self.append_eltwise(lhs, rhs)
            out = paddle.nn.functional.layer_norm(summed, summed.shape[1:])

        # Random feeds, drawn in the order data1, data2.
        self.feeds = {
            name: np.random.random([1, 3, 128, 128]).astype("float32")
            for name in ("data1", "data2")
        }

        self.enable_trt = True
        # NOTE(review): the positional fields are defined by
        # InferencePassTest.TensorRTParam -- see that class for their meaning.
        self.trt_parameters = SkipLayernormFusePassTest0.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False
        )

        def both_inputs(shape):
            # Each input gets its own list object.
            return {'data1': list(shape), 'data2': list(shape)}

        smallest = [1, 1, 1, 128]
        largest = [1, 3, 128, 128]
        self.dynamic_shape_params = (
            SkipLayernormFusePassTest0.DynamicShapeParam(
                both_inputs(smallest),
                both_inputs(largest),
                both_inputs(largest),
                False,
            )
        )
        self.fetch_list = [out]

    def append_eltwise(self, data1, data2):
        """Return the element-wise sum of the two inputs."""
        return paddle.add(data1, data2)

    def test_check_output(self):
        """Drop any stale optimization cache, then check outputs on GPU builds."""
        cache_dir = self.path + "_opt_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return

        self.check_output_with_option(True, atol=0.01, rtol=0.00001)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')
        )
class SkipLayernormFusePassTest1(InferencePassTest):
    """add + layer_norm on [-1, 256, 1536] float32 inputs, TensorRT Float32."""

    def setUp(self):
        """Build the add -> layer_norm program and the TensorRT settings."""
        with fluid.program_guard(self.main_program, self.startup_program):
            lhs = paddle.static.data(
                name="data1", shape=[-1, 256, 1536], dtype="float32"
            )
            rhs = paddle.static.data(
                name="data2", shape=[-1, 256, 1536], dtype="float32"
            )
            summed = self.append_eltwise(lhs, rhs)
            out = paddle.nn.functional.layer_norm(summed, summed.shape[1:])

        # Random feeds, drawn in the order data1, data2.
        self.feeds = {
            name: np.random.random([1, 256, 1536]).astype("float32")
            for name in ("data1", "data2")
        }

        self.enable_trt = True
        # NOTE(review): the positional fields are defined by
        # InferencePassTest.TensorRTParam -- see that class for their meaning.
        self.trt_parameters = SkipLayernormFusePassTest1.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False
        )

        def both_inputs(shape):
            # Each input gets its own list object.
            return {'data1': list(shape), 'data2': list(shape)}

        smallest = [1, 1, 1]
        largest = [1, 384, 1536]
        self.dynamic_shape_params = (
            SkipLayernormFusePassTest1.DynamicShapeParam(
                both_inputs(smallest),
                both_inputs(largest),
                both_inputs(largest),
                False,
            )
        )
        self.fetch_list = [out]

    def append_eltwise(self, data1, data2):
        """Return the element-wise sum of the two inputs."""
        return paddle.add(data1, data2)

    def test_check_output(self):
        """Drop any stale optimization cache, then check outputs on GPU builds."""
        cache_dir = self.path + "_opt_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return

        self.check_output_with_option(True, atol=0.01, rtol=0.00001)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')
        )
class SkipLayernormFusePassTest2(InferencePassTest):
    """add + layer_norm on [-1, 128, 64, 768] float32 inputs, TensorRT Half."""

    def setUp(self):
        """Build the add -> layer_norm program and the TensorRT settings."""
        with fluid.program_guard(self.main_program, self.startup_program):
            lhs = paddle.static.data(
                name="data1", shape=[-1, 128, 64, 768], dtype="float32"
            )
            rhs = paddle.static.data(
                name="data2", shape=[-1, 128, 64, 768], dtype="float32"
            )
            summed = self.append_eltwise(lhs, rhs)
            out = paddle.nn.functional.layer_norm(summed, summed.shape[1:])

        # Random feeds, drawn in the order data1, data2.
        self.feeds = {
            name: np.random.random([1, 128, 64, 768]).astype("float32")
            for name in ("data1", "data2")
        }

        self.enable_trt = True
        # NOTE(review): the positional fields are defined by
        # InferencePassTest.TensorRTParam -- see that class for their meaning.
        self.trt_parameters = SkipLayernormFusePassTest2.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Half, True, False
        )

        def both_inputs(shape):
            # Each input gets its own list object.
            return {'data1': list(shape), 'data2': list(shape)}

        smallest = [1, 1, 1, 1]
        largest = [1, 128, 64, 768]
        self.dynamic_shape_params = (
            SkipLayernormFusePassTest2.DynamicShapeParam(
                both_inputs(smallest),
                both_inputs(largest),
                both_inputs(largest),
                False,
            )
        )
        self.fetch_list = [out]

    def append_eltwise(self, data1, data2):
        """Return the element-wise sum of the two inputs."""
        return paddle.add(data1, data2)

    def test_check_output(self):
        """Drop any stale optimization cache, then check outputs on GPU builds."""
        cache_dir = self.path + "_opt_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return

        # Half precision: looser absolute tolerance than the Float32 cases.
        self.check_output_with_option(True, atol=0.1, rtol=0.00001)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')
        )
class SkipLayernormFusePassTest3(InferencePassTest):
    """add + layer_norm on [-1, 128, 128] float32 inputs, TensorRT Half."""

    def setUp(self):
        """Build the add -> layer_norm program and the TensorRT settings."""
        with fluid.program_guard(self.main_program, self.startup_program):
            lhs = paddle.static.data(
                name="data1", shape=[-1, 128, 128], dtype="float32"
            )
            rhs = paddle.static.data(
                name="data2", shape=[-1, 128, 128], dtype="float32"
            )
            summed = self.append_eltwise(lhs, rhs)
            out = paddle.nn.functional.layer_norm(summed, summed.shape[1:])

        # Random feeds, drawn in the order data1, data2.
        self.feeds = {
            name: np.random.random([1, 128, 128]).astype("float32")
            for name in ("data1", "data2")
        }

        self.enable_trt = True
        # NOTE(review): the positional fields are defined by
        # InferencePassTest.TensorRTParam -- see that class for their meaning.
        self.trt_parameters = SkipLayernormFusePassTest3.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Half, True, False
        )

        def both_inputs(shape):
            # Each input gets its own list object.
            return {'data1': list(shape), 'data2': list(shape)}

        smallest = [1, 1, 1]
        largest = [1, 128, 128]
        self.dynamic_shape_params = (
            SkipLayernormFusePassTest3.DynamicShapeParam(
                both_inputs(smallest),
                both_inputs(largest),
                both_inputs(largest),
                False,
            )
        )
        self.fetch_list = [out]

    def append_eltwise(self, data1, data2):
        """Return the element-wise sum of the two inputs."""
        return paddle.add(data1, data2)

    def test_check_output(self):
        """Drop any stale optimization cache, then check outputs on GPU builds."""
        cache_dir = self.path + "_opt_cache"
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)

        # Nothing is checked on builds without CUDA.
        if not core.is_compiled_with_cuda():
            return

        # Half precision: looser absolute tolerance than the Float32 cases.
        self.check_output_with_option(True, atol=0.1, rtol=0.00001)
        self.assertTrue(
            PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')
        )
# Run the test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录