Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1bec83f4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1bec83f4
编写于
8月 10, 2022
作者:
W
Wangzheee
提交者:
GitHub
8月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
disable_skip_layernorm_fp16 (#45041)
上级
9a04540c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
18 addition
and
10 deletion
+18
-10
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
+18
-10
未找到文件。
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
浏览文件 @
1bec83f4
...
@@ -22,7 +22,8 @@ namespace tensorrt {
...
@@ -22,7 +22,8 @@ namespace tensorrt {
class
SkipLayerNormOpConverter
:
public
OpConverter
{
class
SkipLayerNormOpConverter
:
public
OpConverter
{
public:
public:
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
{
#if IS_TRT_VERSION_GE(6000)
#if IS_TRT_VERSION_GE(6000)
VLOG
(
4
)
<<
"convert fused skip layernorm op to tensorrt layer"
;
VLOG
(
4
)
<<
"convert fused skip layernorm op to tensorrt layer"
;
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
...
@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"3"
);
"CustomSkipLayerNormPluginDynamic"
,
"3"
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
creator
,
nullptr
,
creator
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
const
std
::
vector
<
nvinfer1
::
PluginField
>
fields
{
...
@@ -85,7 +87,8 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -85,7 +87,8 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
plugin_layer
,
nullptr
,
plugin_layer
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
layer
=
plugin_layer
;
layer
=
plugin_layer
;
...
@@ -93,14 +96,16 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -93,14 +96,16 @@ class SkipLayerNormOpConverter : public OpConverter {
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
auto
creator
=
GetPluginRegistry
()
->
getPluginCreator
(
"CustomSkipLayerNormPluginDynamic"
,
"2"
);
"CustomSkipLayerNormPluginDynamic"
,
"2"
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
creator
,
nullptr
,
creator
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
"fail to get creator of CustomSkipLayerNormPluginDynamic"
));
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
int
type
=
static_cast
<
int
>
((
engine_
->
WithFp16
()
==
1
)
?
nvinfer1
::
DataType
::
kHALF
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
);
:
nvinfer1
::
DataType
::
kFLOAT
);
int
ld
=
input1
->
getDimensions
().
d
[
2
];
// hidden dimension
int
ld
=
input1
->
getDimensions
().
d
[
2
];
// hidden dimension
PADDLE_ENFORCE_GT
(
ld
,
0
,
PADDLE_ENFORCE_GT
(
ld
,
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"in CustomSkipLayerNormPluginDynamic hidden "
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"
));
"dimension should > 0"
));
...
@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
...
@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
inputs
.
data
(),
inputs
.
size
(),
*
pluginObj
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
plugin_layer
,
nullptr
,
plugin_layer
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
"fail to add CustomSkipLayerNormPluginDynamic layer"
));
layer
=
plugin_layer
;
layer
=
plugin_layer
;
}
}
}
else
{
}
else
{
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
float
eps
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
bool
with_fp16
=
/* bool with_fp16 =
engine_
->
WithFp16
()
&&
!
engine_
->
disable_trt_plugin_fp16
();
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
*/
bool
with_fp16
=
false
;
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
plugin
::
SkipLayerNormPluginDynamic
*
plugin
=
new
plugin
::
SkipLayerNormPluginDynamic
(
bias
,
scale
,
bias_size
,
new
plugin
::
SkipLayerNormPluginDynamic
(
scale_size
,
eps
,
with_fp16
);
bias
,
scale
,
bias_size
,
scale_size
,
eps
,
with_fp16
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
2
,
plugin
);
layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
2
,
plugin
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录