Commit 591be3bd (unverified)
Authored Jan 09, 2023 by wenbin; committed by GitHub on Jan 09, 2023
Preln groupnorm (#49463)
* skip_groupnorm
* init
* preln
* add ut
* more assert
* set timeout
* fix windows ci issue
Parent: aaa25222

Showing 22 changed files with 2,580 additions and 6 deletions (+2580 −6)
paddle/fluid/framework/ir/CMakeLists.txt (+2 −0)
paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.cc (+204 −0)
paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h (+98 −0)
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc (+196 −0)
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h (+96 −0)
paddle/fluid/inference/api/analysis_predictor.cc (+2 −0)
paddle/fluid/inference/api/paddle_pass_builder.cc (+9 −4)
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt (+2 −0)
paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc (+94 −0)
paddle/fluid/inference/tensorrt/convert/skip_groupnorm_act_op.cc (+92 −0)
paddle/fluid/inference/tensorrt/op_teller.cc (+22 −2)
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt (+2 −0)
paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h (+2 −0)
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu (+456 −0)
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h (+194 −0)
paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu (+463 −0)
paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h (+194 −0)
paddle/fluid/operators/compat/group_norm.pbtxt (+67 −0)
paddle/fluid/operators/compat/silu.pbtxt (+31 −0)
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt (+8 −0)
python/paddle/fluid/tests/unittests/ir/inference/test_element_groupnorm_act_fuse_pass.py (+173 −0)
python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py (+173 −0)
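Both new passes collapse the subgraph elementwise_add -> group_norm -> silu into a single plugin op that TensorRT can run as one fused kernel pair. As a rough reference for the fused semantics, here is a minimal CPU sketch in plain C++ (illustrative only: the function and parameter names are not Paddle APIs, the real plugins run in fp16 NHWC, and the skip variant additionally asserts that Y has shape [N, C] or [N, C, 1, 1] and broadcasts it over H and W, as below):

#include <cmath>
#include <vector>

// silu(v) = v * sigmoid(v), matching the fused activation.
static float silu(float v) { return v / (1.0f + std::exp(-v)); }

// Reference for skip_groupnorm_act: out = silu(group_norm(x + broadcast(y))).
// x is [N, C, H, W] row-major; y is [N, C] (the pass also accepts [N, C, 1, 1]).
std::vector<float> skip_groupnorm_act_ref(const std::vector<float>& x,
                                          const std::vector<float>& y,
                                          const std::vector<float>& scale,
                                          const std::vector<float>& bias,
                                          int N, int C, int H, int W,
                                          int groups, float eps) {
  const int HW = H * W;
  const int cPerGroup = C / groups;
  std::vector<float> s(x);  // x + y, with y broadcast over spatial positions
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int hw = 0; hw < HW; ++hw) s[(n * C + c) * HW + hw] += y[n * C + c];

  std::vector<float> out(s.size());
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < groups; ++g) {
      // Statistics over one group's channels and all spatial positions.
      double sum = 0.0, sumSq = 0.0;
      for (int c = g * cPerGroup; c < (g + 1) * cPerGroup; ++c)
        for (int hw = 0; hw < HW; ++hw) {
          const double v = s[(n * C + c) * HW + hw];
          sum += v;
          sumSq += v * v;
        }
      const double cnt = static_cast<double>(cPerGroup) * HW;
      const double mean = sum / cnt;
      const double var = sumSq / cnt - mean * mean;  // E[v^2] - E[v]^2
      const float invStdDev = 1.0f / std::sqrt(static_cast<float>(var) + eps);
      for (int c = g * cPerGroup; c < (g + 1) * cPerGroup; ++c)
        for (int hw = 0; hw < HW; ++hw) {
          const int i = (n * C + c) * HW + hw;
          const float norm =
              (s[i] - static_cast<float>(mean)) * invStdDev * scale[c] + bias[c];
          out[i] = silu(norm);
        }
    }
  }
  return out;
}

The GPU plugins below compute the same statistics with a per-block partial-sum kernel followed by a normalize-and-scale kernel.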
paddle/fluid/framework/ir/CMakeLists.txt

@@ -139,6 +139,8 @@ if(WITH_TENSORRT)
   pass_library(layernorm_shift_partition_fuse_pass inference)
   pass_library(reverse_roll_fuse_pass inference)
   pass_library(preln_layernorm_x_fuse_pass inference)
+  pass_library(elementwise_groupnorm_act_pass inference)
+  pass_library(preln_elementwise_groupnorm_act_pass inference)
 endif()
 if(WITH_TENSORRT)
paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.cc (new file, mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h"

#include <string>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {
class Node;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct SkipGroupNormAct : public PatternBase {
  SkipGroupNormAct(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "skip_groupnorm_act") {}

  void operator()(PDNode *x, PDNode *y);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise);
  PATTERN_DECL_NODE(group_norm);
  // declare variable node's name
  PATTERN_DECL_NODE(elementwise_out);
  PATTERN_DECL_NODE(group_norm_bias);
  PATTERN_DECL_NODE(group_norm_scale);
  PATTERN_DECL_NODE(group_norm_out);
  PATTERN_DECL_NODE(act);
  PATTERN_DECL_NODE(act_out);
};

void SkipGroupNormAct::operator()(PDNode *x, PDNode *y) {
  auto *elementwise = pattern->NewNode(elementwise_repr())
                          ->assert_is_op("elementwise_add")
                          ->assert_has_n_outputs(1);
  auto *elementwise_out_var =
      pattern->NewNode(elementwise_out_repr())
          ->assert_is_op_output("elementwise_add", "Out")
          ->assert_has_n_outputs(1)
          ->assert_is_op_input("group_norm", "X");

  elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});

  // Create nodes for group_norm op.
  auto *group_norm =
      pattern->NewNode(group_norm_repr())->assert_is_op("group_norm");
  auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr())
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("group_norm", "Bias");
  auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr())
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("group_norm", "Scale");
  auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
                                 ->AsOutput()
                                 ->assert_is_op_output("group_norm", "Y")
                                 ->assert_is_op_input("silu", "X");

  // Add links for group_norm op.
  group_norm
      ->LinksFrom(
          {elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
      .LinksTo({group_norm_out_var});

  auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
  auto *act_out = pattern->NewNode(act_out_repr())
                      ->AsOutput()
                      ->assert_is_op_output("silu", "Out");

  act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
}

}  // namespace patterns

int SkipGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("skip_groupnorm_silu_fuse", graph);

  int found_subgraph_count = 0;
  GraphPatternDetector gpd;

  PDNode *x = nullptr;
  PDNode *y = nullptr;
  x = gpd.mutable_pattern()
          ->NewNode("skip_groupnorm_act_fuse/x")
          ->AsInput()
          ->assert_var_not_persistable()
          ->assert_is_op_input("elementwise_add", "X");
  y = gpd.mutable_pattern()
          ->NewNode("skip_groupnorm_act_fuse/y")
          ->AsInput()
          ->assert_var_not_persistable()
          ->assert_is_op_input("elementwise_add", "Y")
          ->assert_more([&](Node *x) {
            auto shape = x->Var()->GetShape();
            if (shape.size() == 2 ||
                (shape.size() == 4 && shape[3] == 1 && shape[2] == 1))
              return true;
            else
              return false;
          });

  patterns::SkipGroupNormAct fused_pattern(gpd.mutable_pattern(),
                                           "skip_groupnorm_act_fuse");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }

    VLOG(4) << "handle skip groupnorm act fuse";

    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm_scale, group_norm_scale, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);

    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "skip groupnorm act pass in op compat failed.";
      return;
    }

    std::unordered_set<const Node *> del_node_set;

    // Create a skip_groupnorm_act op node
    OpDesc new_desc(*group_norm->Op());
    new_desc.SetType("skip_groupnorm_act");
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
    new_desc.SetOutput("Out", {act_out->Name()});
    new_desc.RemoveOutput("Y");
    new_desc.Flush();

    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.

    del_node_set.insert(elementwise);
    del_node_set.insert(group_norm);
    del_node_set.insert(elementwise_out);
    del_node_set.insert(group_norm_out);
    del_node_set.insert(act);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
    IR_NODE_LINK_TO(group_norm_scale, fused_node);
    IR_NODE_LINK_TO(group_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, act_out);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  return found_subgraph_count;
}

void SkipGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
  FusePassBase::Init("skip_groupnorm_act_fuse_pass", graph);
  int found_subgraph_count = ApplyGNSiluPattern(graph);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(elementwise_groupnorm_act_pass,
              paddle::framework::ir::SkipGroupNormActFusePass);
REGISTER_PASS_CAPABILITY(elementwise_groupnorm_act_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("elementwise_add", 1)
            .EQ("silu", 0)
            .EQ("group_norm", 0));
paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h (new file, mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {
//
//      |   |                  |   |
// elementwise_add    fuse     |   |
//      |              ->  skip_gn_act
//  group_norm                  |
//      |
//     silu
//      |
//
class Graph;

class SkipGroupNormActFusePass : public FusePassBase {
 public:
  SkipGroupNormActFusePass() {
    AddOpCompat(OpCompat("elementwise_add"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Y")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({0, -1})
        .End();
    AddOpCompat(OpCompat("group_norm"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Scale")
        .IsTensor()
        .End()
        .AddInput("Bias")
        .IsTensor()
        .End()
        .AddOutput("Y")
        .IsTensor()
        .End()
        .AddOutput("Mean")
        .IsTensor()
        .End()
        .AddOutput("Variance")
        .IsTensor()
        .End()
        .AddAttr("epsilon")
        .IsNumGE(0.0f)
        .IsNumLE(1.0f)
        .End()
        .AddAttr("groups")
        .IsNumGE(1)
        .End()
        .AddAttr("data_layout")
        .IsStringIn({"NCHW"})
        .End();
    AddOpCompat(OpCompat("silu"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End();
  }
  virtual ~SkipGroupNormActFusePass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;
  int ApplyGNSiluPattern(ir::Graph *graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc (new file, mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h"

#include <string>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {
class Node;
}  // namespace ir
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct PrelnGroupNormAct : public PatternBase {
  PrelnGroupNormAct(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "preln_groupnorm_act") {}

  void operator()(PDNode *x, PDNode *y);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise);
  PATTERN_DECL_NODE(group_norm);
  // declare variable node's name
  PATTERN_DECL_NODE(elementwise_out);
  PATTERN_DECL_NODE(group_norm_bias);
  PATTERN_DECL_NODE(group_norm_scale);
  PATTERN_DECL_NODE(group_norm_out);
  PATTERN_DECL_NODE(act);
  PATTERN_DECL_NODE(act_out);
};

void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
  auto *elementwise = pattern->NewNode(elementwise_repr())
                          ->assert_is_op("elementwise_add");
  auto *elementwise_out_var =
      pattern->NewNode(elementwise_out_repr())
          ->assert_is_op_output("elementwise_add", "Out")
          ->assert_is_op_input("group_norm", "X");

  elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});

  // Create nodes for group_norm op.
  auto *group_norm =
      pattern->NewNode(group_norm_repr())->assert_is_op("group_norm");
  auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr())
                                  ->AsInput()
                                  ->assert_is_persistable_var()
                                  ->assert_is_op_input("group_norm", "Bias");
  auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr())
                                   ->AsInput()
                                   ->assert_is_persistable_var()
                                   ->assert_is_op_input("group_norm", "Scale");
  auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
                                 ->AsOutput()
                                 ->assert_is_op_output("group_norm", "Y")
                                 ->assert_is_op_input("silu", "X");

  // Add links for group_norm op.
  group_norm
      ->LinksFrom(
          {elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
      .LinksTo({group_norm_out_var});

  auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
  auto *act_out = pattern->NewNode(act_out_repr())
                      ->AsOutput()
                      ->assert_is_op_output("silu", "Out");

  act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
}

}  // namespace patterns

int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("preln_groupnorm_silu_fuse", graph);

  int found_subgraph_count = 0;
  GraphPatternDetector gpd;

  PDNode *x = nullptr;
  PDNode *y = nullptr;
  x = gpd.mutable_pattern()
          ->NewNode("preln_groupnorm_act_fuse/x")
          ->AsInput()
          ->assert_var_not_persistable()
          ->assert_is_op_input("elementwise_add", "X");
  y = gpd.mutable_pattern()
          ->NewNode("preln_groupnorm_act_fuse/y")
          ->AsInput()
          ->assert_var_not_persistable()
          ->assert_is_op_input("elementwise_add", "Y");

  patterns::PrelnGroupNormAct fused_pattern(gpd.mutable_pattern(),
                                            "preln_groupnorm_act_fuse");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }

    VLOG(4) << "handle preln groupnorm act fuse";

    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm_scale, group_norm_scale, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);

    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "preln groupnorm act pass in op compat failed.";
      return;
    }

    std::unordered_set<const Node *> del_node_set;

    // Create a preln_groupnorm_act op node
    OpDesc new_desc(*group_norm->Op());
    new_desc.SetType("preln_groupnorm_act");
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
    new_desc.SetOutput("Out_0", {elementwise_out->Name()});
    new_desc.SetOutput("Out_1", {act_out->Name()});
    new_desc.RemoveOutput("Y");
    new_desc.Flush();

    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.

    del_node_set.insert(elementwise);
    del_node_set.insert(group_norm);
    del_node_set.insert(group_norm_out);
    del_node_set.insert(act);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
    IR_NODE_LINK_TO(group_norm_scale, fused_node);
    IR_NODE_LINK_TO(group_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, act_out);
    IR_NODE_LINK_TO(fused_node, elementwise_out);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  return found_subgraph_count;
}

void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
  FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph);
  int found_subgraph_count = ApplyGNSiluPattern(graph);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(preln_elementwise_groupnorm_act_pass,
              paddle::framework::ir::PrelnGroupNormActFusePass);
REGISTER_PASS_CAPABILITY(preln_elementwise_groupnorm_act_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("elementwise_add", 1)
            .EQ("silu", 0)
            .EQ("group_norm", 0));
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h (new file, mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {
//
//       |    |                       |    |
//  elementwise_add       fuse        |    |
//       |    |            ->    preln_gn_act
//  other op  group_norm              |    |
//                |                other op
//              silu
//                |
//
class Graph;

class PrelnGroupNormActFusePass : public FusePassBase {
 public:
  PrelnGroupNormActFusePass() {
    AddOpCompat(OpCompat("elementwise_add"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Y")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        .IsIntIn({0, -1})
        .End();
    AddOpCompat(OpCompat("group_norm"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Scale")
        .IsTensor()
        .End()
        .AddInput("Bias")
        .IsTensor()
        .End()
        .AddOutput("Y")
        .IsTensor()
        .End()
        .AddOutput("Mean")
        .IsTensor()
        .End()
        .AddOutput("Variance")
        .IsTensor()
        .End()
        .AddAttr("epsilon")
        .IsNumGE(0.0f)
        .IsNumLE(1.0f)
        .End()
        .AddAttr("groups")
        .IsNumGE(1)
        .End()
        .AddAttr("data_layout")
        .IsStringIn({"NCHW"})
        .End();
    AddOpCompat(OpCompat("silu"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End();
  }
  virtual ~PrelnGroupNormActFusePass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;
  int ApplyGNSiluPattern(ir::Graph *graph) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
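The difference from the skip variant is visible in the handler above: the fused op keeps the elementwise_add result alive as Out_0 (elementwise_out is not deleted, and the fused node is linked back to it) because another op still consumes the pre-normalization sum. A minimal sketch of the two-output contract, under the same illustrative naming as the earlier reference (group_norm and silu elided):

#include <cstddef>
#include <utility>
#include <vector>

// preln_groupnorm_act: Out_0 = x + y (residual branch for the "other op"),
// Out_1 = silu(group_norm(x + y)). Unlike the skip pass, no broadcast-shape
// assert is placed on y, so x and y are summed elementwise here.
std::pair<std::vector<float>, std::vector<float>> preln_groupnorm_act_ref(
    const std::vector<float>& x, const std::vector<float>& y) {
  std::vector<float> out0(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) out0[i] = x[i] + y[i];
  std::vector<float> out1 = out0;
  // ... apply group_norm + silu to out1 as in the earlier sketch ...
  return {out0, out1};
}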
paddle/fluid/inference/api/analysis_predictor.cc

@@ -2420,6 +2420,8 @@ USE_TRT_CONVERTER(logsigmoid)
 USE_TRT_CONVERTER(lookup_table)
 USE_TRT_CONVERTER(expand_v2)
 USE_TRT_CONVERTER(take_along_axis)
+USE_TRT_CONVERTER(skip_groupnorm_act)
+USE_TRT_CONVERTER(preln_groupnorm_act)
 #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
 USE_TRT_CONVERTER(sparse_fc)
 USE_TRT_CONVERTER(sparse_multihead_matmul)
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -106,8 +106,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "vit_attention_fuse_pass",             //
 #if defined _WIN32  // Windows CI is TensorRT7.0. Remove this after upgrading.
 #else
-      "trt_skip_layernorm_fuse_pass",         //
-      "preln_skip_layernorm_fuse_pass",       //
+      "trt_skip_layernorm_fuse_pass",        //
+      "preln_skip_layernorm_fuse_pass",      //
 #endif
       "layernorm_shift_partition_fuse_pass",  //
       "merge_layernorm_fuse_pass",            //

@@ -128,8 +128,13 @@ const std::vector<std::string> kTRTSubgraphPasses({
       // "yolo_box_fuse_pass",      //
       "dense_fc_to_sparse_pass",                //
       "dense_multihead_matmul_to_sparse_pass",  //
-      "tensorrt_subgraph_pass",                 //
-      "conv_bn_fuse_pass",                      //
+#if defined _WIN32  // Windows CI is TensorRT7.0. Remove this after upgrading.
+#else
+      "elementwise_groupnorm_act_pass",         //
+      "preln_elementwise_groupnorm_act_pass",   //
+#endif
+      "tensorrt_subgraph_pass",                 //
+      "conv_bn_fuse_pass",                      //
 #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
                           // guaranteed at least v7
 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we
paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

@@ -94,6 +94,8 @@ list(
     skip_merge_layernorm_op.cc
     generic_and_custom_plugin_creater.cc
     fused_lookup_tables_op.cc
+    skip_groupnorm_act_op.cc
+    preln_groupnorm_act_op.cc
     expand_v2_op.cc)

 if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h"

#include <vector>

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

class PrelnGroupnormActOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc &op,
                  const framework::Scope &scope,
                  bool test_mode) override {
    VLOG(4) << "convert a fluid preln_groupnorm_act op to tensorrt "
               "preln_groupnorm_act plugin";

    framework::OpDesc op_desc(op, nullptr);

    auto *input_x = engine_->GetITensor(op_desc.Input("X").front());
    auto *input_y = engine_->GetITensor(op_desc.Input("Y").front());
    std::vector<nvinfer1::ITensor *> inputs{input_x, input_y};

    int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
    float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));

    std::string scale_name = op_desc.Input("Scale").front();
    std::string bias_name = op_desc.Input("Bias").front();

    // get the persistable var's data
    auto GetWeight = [&](const std::string &var_name,
                         framework::DDim *dims) -> TensorRTEngine::Weight {
      auto *temp_var = scope.FindVar(var_name);
      auto *temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
      (*dims) = temp_tensor->dims();
      auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
      return weight;
    };

    framework::DDim scale_dims;
    framework::DDim bias_dims;
    auto scale_weights = GetWeight(scale_name, &scale_dims);
    auto bias_weights = GetWeight(bias_name, &bias_dims);

    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();

    if (engine_->with_dynamic_shape()) {
      plugin::PrelnGroupnormActPluginDynamic *plugin =
          new plugin::PrelnGroupnormActPluginDynamic(
              static_cast<const float *>(scale_weights.get().values),
              scale_weights.get().count,
              static_cast<const float *>(bias_weights.get().values),
              bias_weights.get().count,
              epsilon,
              groups,
              with_fp16);
      nvinfer1::ILayer *groupnorm_layer =
          engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
      std::vector<std::string> output_names;
      output_names.emplace_back(op_desc.Output("Out_0").front());
      output_names.emplace_back(op_desc.Output("Out_1").front());
      RreplenishLayerAndOutput(
          groupnorm_layer, "preln_groupnorm_act", output_names, test_mode);
    }
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(preln_groupnorm_act, PrelnGroupnormActOpConverter);
paddle/fluid/inference/tensorrt/convert/skip_groupnorm_act_op.cc (new file, mode 100644)

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h"

#include <vector>

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

class SkipGroupnormActOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc &op,
                  const framework::Scope &scope,
                  bool test_mode) override {
    VLOG(4) << "convert a fluid skip_groupnorm_act op to tensorrt "
               "skip_groupnorm_act plugin";

    framework::OpDesc op_desc(op, nullptr);

    auto *inputx = engine_->GetITensor(op_desc.Input("X").front());
    auto *inputy = engine_->GetITensor(op_desc.Input("Y").front());
    std::vector<nvinfer1::ITensor *> inputs{inputx, inputy};

    int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
    float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));

    std::string scale_name = op_desc.Input("Scale").front();
    std::string bias_name = op_desc.Input("Bias").front();

    // get the persistable var's data
    auto GetWeight = [&](const std::string &var_name,
                         framework::DDim *dims) -> TensorRTEngine::Weight {
      auto *temp_var = scope.FindVar(var_name);
      auto *temp_tensor = temp_var->GetMutable<phi::DenseTensor>();
      (*dims) = temp_tensor->dims();
      auto weight = engine_->GetTrtWeight(var_name, *temp_tensor);
      return weight;
    };

    framework::DDim scale_dims;
    framework::DDim bias_dims;
    auto scale_weights = GetWeight(scale_name, &scale_dims);
    auto bias_weights = GetWeight(bias_name, &bias_dims);

    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();

    if (engine_->with_dynamic_shape()) {
      plugin::SkipGroupnormActPluginDynamic *plugin =
          new plugin::SkipGroupnormActPluginDynamic(
              static_cast<const float *>(scale_weights.get().values),
              scale_weights.get().count,
              static_cast<const float *>(bias_weights.get().values),
              bias_weights.get().count,
              epsilon,
              groups,
              with_fp16);
      nvinfer1::ILayer *groupnorm_layer =
          engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
      auto output_name = op_desc.Output("Out")[0];
      RreplenishLayerAndOutput(
          groupnorm_layer, "skip_groupnorm_act", {output_name}, test_mode);
    }
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(skip_groupnorm_act, SkipGroupnormActOpConverter);
paddle/fluid/inference/tensorrt/op_teller.cc

@@ -2390,6 +2390,22 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }

+    if (op_type == "skip_groupnorm_act") {
+      if (!with_dynamic_shape) {
+        VLOG(3) << "The skip_groupnorm_act op does not support "
+                   "static shape yet";
+        return false;
+      }
+    }
+
+    if (op_type == "preln_groupnorm_act") {
+      if (!with_dynamic_shape) {
+        VLOG(3) << "The preln_groupnorm_act op does not support "
+                   "static shape yet";
+        return false;
+      }
+    }
+
     if (op_type == "lookup_table") {
       if (!with_dynamic_shape) {
         VLOG(3) << "the lookup_table does not support "

@@ -2561,7 +2577,9 @@ struct SimpleOpTypeSetTeller : public Teller {
       "merge_layernorm",
       "skip_merge_layernorm",
       "lookup_table_v2",
-      "expand_v2"};
+      "expand_v2",
+      "skip_groupnorm_act",
+      "preln_groupnorm_act"};

   std::unordered_set<std::string> teller_set{
       "mul",

@@ -2709,7 +2727,9 @@ struct SimpleOpTypeSetTeller : public Teller {
       "skip_merge_layernorm",
       "lookup_table",
       "lookup_table_v2",
-      "expand_v2"};
+      "expand_v2",
+      "skip_groupnorm_act",
+      "preln_groupnorm_act"};
 };

 struct GenericPluginTeller : public Teller {
paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt

@@ -35,6 +35,8 @@ list(
     prelnlayernorm_shift_partition_op.cu
     merge_layernorm_op_plugin.cu
     skip_merge_layernorm_op_plugin.cu
+    skip_groupnorm_act_op_plugin.cu
+    preln_groupnorm_act_op_plugin.cu
     generic_plugin.cu
     lookup_table.cu
     many_emb_layernorm_plugin.cu
paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h

@@ -27,6 +27,8 @@ namespace plugin {
 struct GroupNormNHWCParams {
   // The output buffer. Layout NHWC.
   __half *dst;
+  // The output buffer. Layout NHWC.
+  __half *eleOut;
   // The input buffer. Layout NHWC.
   __half const *srcX;
   // The input buffer. Layout NHWC.
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu (new file, mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES.
All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h"

#include <cub/cub.cuh>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

nvinfer1::DimsExprs PrelnGroupnormActPluginDynamic::getOutputDimensions(
    int output_index,
    const nvinfer1::DimsExprs *inputDims,
    int nb_inputs,
    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  return inputDims[0];
}

bool PrelnGroupnormActPluginDynamic::supportsFormatCombination(
    int pos,
    const nvinfer1::PluginTensorDesc *in_out,
    int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of prelnGroupnormAct plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos,
      nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "num(%d) of the input and the output.",
                                        pos,
                                        nb_inputs + nb_outputs));

  const nvinfer1::PluginTensorDesc &in = in_out[pos];
  if (pos == 0) {
    if (with_fp16_) {
      return ((in.type == nvinfer1::DataType::kHALF) &&
              (in.format == nvinfer1::PluginFormat::kHWC8));
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "PrelnGroupnormAct TRT Plugin is fp16 only so far"));
      return (in.type == nvinfer1::DataType::kFLOAT) &&
             (in.format == nvinfer1::TensorFormat::kLINEAR);
    }
  }
  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
  // output
  return in.type == prev.type && in.format == prev.format;
}

nvinfer1::DataType PrelnGroupnormActPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType *input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_types[0];
}

int PrelnGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }

static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }

static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
  int32_t maxDivisor = -1;
  for (int32_t i = 1; i <= std::sqrt(n); i++) {
    if (n % i == 0) {
      int32_t divisor1 = n / i;
      int32_t divisor2 = i;
      if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
        maxDivisor = divisor1;
      }
      if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
        maxDivisor = divisor2;
      }
    }
  }
  return maxDivisor;
}

static inline __device__ __host__ float sigmoid(float x) {
  return 1.F / (1.F + expf(-x));
}

struct GroupSums {
  // Is it the 1st element of the group?
  int32_t flag;
  // The sum.
  float sum;
  // The sum of squares.
  float sumSq;
};

struct GroupSumsOp {
  inline __device__ GroupSums operator()(GroupSums const &a,
                                         GroupSums const &b) {
    GroupSums dst;
    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
    dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
    dst.flag = a.flag + b.flag;
    return dst;
  }
};

template <int32_t tTHREADS_PER_BLOCK>
__global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
  // The object in charge of doing the sums for the different blocks.
  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;

  // Allocate shared memory for BlockScan.
  __shared__ typename BlockScan::TempStorage tempStorage;
  // Allocate shared memory for the groups. We could reduce the amount of shared
  // memory reserved.
  __shared__ float2 smem[tTHREADS_PER_BLOCK];

  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;

  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

  // The sums.
  float sum = 0.F;
  float sumSq = 0.F;

  // Iterate over the activations to compute the sums.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The offset.
    int64_t offset = static_cast<int64_t>(ni) * params.hwc +
                     static_cast<int64_t>(hwi) * params.c + ci;

    // Fetch two channels per thread.
    __half2 h2(0, 0);
    if (ci < params.c) {
      // int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
      __half2 y = *reinterpret_cast<__half2 const *>(&params.srcY[offset]);
      h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
      h2 = __hadd2(h2, y);
      // elementwise_add
      *reinterpret_cast<__half2 *>(&params.eleOut[offset]) = h2;
    }

    // Extract the two half values.
    float2 f2 = __half22float2(h2);

    // Update the sum.
    sum += f2.x + f2.y;
    // Update the sum of squares.
    sumSq += f2.x * f2.x + f2.y * f2.y;
  }

  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;

  // The data for the summations.
  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};

  // Do the segmented scan.
  GroupSums out;
  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());

  // Store the results for the groups in shared memory (to produce coalesced
  // stores later).
  if (cj == params.cPerGroup - 2 /* 2 channels per thread */) {
    smem[gi] = make_float2(out.sum, out.sumSq);
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The global group index.
  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;

  // Threads that have nothing left to do, exit.
  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
    return;
  }

  // The first threads (those storing to global memory, load the values).
  float2 sums = smem[threadIdx.x];

  // Store to global memory.
  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}

void prelnGroupNormNHWCSum(GroupNormNHWCParams const &params,
                           cudaStream_t stream) {
  // Make sure the values are as we expect.
  PADDLE_ENFORCE_EQ(
      params.c % params.cPerBlock,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCSum of prelnGroupnormAct Plugin got "
          "wrong parameters"
          "params.c %% params.cPerBlock should be 0, but get %d.",
          params.c % params.cPerBlock));
  PADDLE_ENFORCE_EQ(
      params.hw % params.hwPerBlock,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong "
          "parameters"
          "params.hw %% params.hwPerBlock should be 0, but get %d.",
          params.hw % params.hwPerBlock));
  // Make sure a group does not span multiple blocks.
  PADDLE_ENFORCE_EQ(
      params.cPerBlock % params.cPerGroup,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong "
          "parameters"
          "params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
          params.cPerBlock % params.cPerGroup));

  dim3 grid;
  // The number of blocks to compute all the channels.
  grid.x = params.c / params.cPerBlock;
  // The number of blocks to compute all the activations in a given instance.
  grid.y = divUp(params.hw, params.hwPerBlock);
  // The number of instances.
  grid.z = params.n;

  switch (params.cPerBlock) {
    case 320:
      prelnGroupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
      break;
    case 480:
      prelnGroupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
      break;
    case 256:
      prelnGroupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
      break;
    case 128:
      prelnGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
      break;
    default:
      PADDLE_THROW(platform::errors::Fatal(
          "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
          "encountered an error"));
  }
}

template <int32_t tTHREADS_PER_BLOCK>
__global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = ci / params.cPerGroup;

  // Load the sum and sum of squares for the group.
  float sum = 0.F, sumSq = 0.F;
  if (gi < params.groups) {
    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
  }

  // Load gamma/beta.
  float2 gammaF2, betaF2;
  if (ci < params.c) {
    gammaF2 = *reinterpret_cast<float2 const *>(
        reinterpret_cast<float const *>(params.gamma) + ci);
    betaF2 = *reinterpret_cast<float2 const *>(
        reinterpret_cast<float const *>(params.beta) + ci);
  }

  // Compute the mean.
  float mean = sum * params.invHWC;
  // Compute the variance.
  float var = sumSq * params.invHWC - (mean * mean);
  // Compute the inverse of the stddev.
  float invStdDev = rsqrtf(var + params.eps);

  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

  // Iterate over the activations to compute the sums.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The src/dst offset.
    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;

    // Fetch two channels per thread.
    __half2 h2(0, 0);
    if (ci < params.c) {
      h2 = *reinterpret_cast<__half2 const *>(&params.eleOut[offset]);
    }

    // Extract the two half values.
    float2 f2 = __half22float2(h2);

    // Normalize the channels.
    f2.x = (f2.x - mean) * invStdDev;
    f2.y = (f2.y - mean) * invStdDev;

    // Scale by gamma and add beta.
    f2.x = gammaF2.x * f2.x + betaF2.x;
    f2.y = gammaF2.y * f2.y + betaF2.y;

    // Apply Swish if needed.
    if (params.withSwish) {
      f2.x = f2.x * sigmoid(f2.x);
      f2.y = f2.y * sigmoid(f2.y);
    }

    // Store the scaled values.
    if (ci < params.c) {
      *reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
    }
  }
}

void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params,
                             cudaStream_t stream) {
  // Make sure the dimensions are aligned with what we expect.
  PADDLE_ENFORCE_EQ(
      params.c % params.cPerBlock,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCScale of prelnGroupnormAct Plugin got "
          "wrong parameters"
          "params.c %% params.cPerBlock should be 0, but get %d.",
          params.c % params.cPerBlock));
  // Make sure a group does not span multiple blocks.
  PADDLE_ENFORCE_EQ(
      params.cPerBlock % params.cPerGroup,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCScale of prelnGroupnormAct Plugin got wrong "
          "parameters"
          "params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
          params.cPerBlock % params.cPerGroup));

  dim3 grid;
  // The number of blocks to compute all the channels.
  grid.x = params.c / params.cPerBlock;
  // The number of blocks to compute all the activations in a given instance.
  grid.y = divUp(params.hw, params.hwPerBlock);
  // The number of instances.
  grid.z = params.n;

  switch (params.cPerBlock) {
    case 320:
      prelnGroupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
      break;
    case 480:
      prelnGroupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
      break;
    case 256:
      prelnGroupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
      break;
    case 128:
      prelnGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
      break;
    default:
      PADDLE_THROW(platform::errors::Fatal(
          "The function groupNormNHWCScale of prelnGroupnormAct TRT Plugin "
          "encountered an error"));
  }
}

int PrelnGroupnormActPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc *input_desc,
    const nvinfer1::PluginTensorDesc *output_desc,
    const void *const *inputs,
    void *const *outputs,
    void *workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp32";
    PADDLE_THROW(platform::errors::Fatal(
        "The prelnGroupnormAct TRT Plugin only supports fp16 input"));
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp16";
    int32_t cPerBlock = 320;
    int32_t maxBlocksPerHW = 1024;

    switch (input_desc[0].dims.d[1]) {
      case 960:
      case 1920:
        cPerBlock = 480;
        break;
      case 512:
      case 256:
        cPerBlock = 256;
        break;
      case 128:
        cPerBlock = 128;
        break;
      default:
        cPerBlock = 320;
    }
    params_.withSwish = true;
    params_.dst = static_cast<half *>(outputs[1]);
    params_.eleOut = static_cast<half *>(outputs[0]);
    params_.srcX = static_cast<half const *>(inputs[0]);
    params_.srcY = static_cast<half const *>(inputs[1]);
    params_.gamma = scale_gpu_.get();
    params_.beta = bias_gpu_.get();
    params_.redBuffer = static_cast<float *>(workspace);
    params_.n = input_desc[0].dims.d[0];
    params_.h = input_desc[0].dims.d[2];
    params_.w = input_desc[0].dims.d[3];
    params_.c = input_desc[0].dims.d[1];
    params_.groups = groups_;
    params_.hw = params_.h * params_.w;
    const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
    params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
    params_.cPerBlock = cPerBlock;
    params_.cPerGroup = params_.c / params_.groups;
    params_.hwc = params_.hw * params_.c;
    params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
    params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
    params_.eps = eps_;

    cudaMemsetAsync(params_.redBuffer, 0, ws_, stream);
    prelnGroupNormNHWCSum(params_, stream);
    prelnGroupNormNHWCScale(params_, stream);
  } else {
    // input not fp16
    PADDLE_THROW(platform::errors::Fatal(
        "The PrelnGroupnormAct TRT Plugin only supports fp16 input"));
  }
  return cudaGetLastError() != cudaSuccess;
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
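The sum kernel above relies on a segmented inclusive scan: GroupSumsOp restarts the running totals whenever the right-hand operand carries flag == 1 (set on the first channel pair of each group), so one cub::BlockScan pass yields independent totals for every group in the block. A standalone CPU analogue of that combine rule (hypothetical demo code, not part of the plugin):

#include <cstdint>
#include <cstdio>
#include <vector>

struct GroupSums {
  int32_t flag;  // 1 on the first element of a group
  float sum;
  float sumSq;
};

// Same combine rule as GroupSumsOp in the kernel: a flagged right-hand side
// restarts the running sums, so each group accumulates independently.
GroupSums combine(GroupSums a, GroupSums b) {
  GroupSums dst;
  dst.sum = b.flag ? b.sum : (a.sum + b.sum);
  dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
  dst.flag = a.flag + b.flag;
  return dst;
}

int main() {
  // Two groups of two values each: {1, 2} and {3, 4}.
  std::vector<GroupSums> in = {
      {1, 1.f, 1.f}, {0, 2.f, 4.f}, {1, 3.f, 9.f}, {0, 4.f, 16.f}};
  GroupSums run = in[0];
  for (std::size_t i = 1; i < in.size(); ++i) {
    run = combine(run, in[i]);
    // After the last element of each group, run.sum/run.sumSq hold that
    // group's totals: (3, 5) for group 0 and (7, 25) for group 1.
    std::printf("i=%zu sum=%.0f sumSq=%.0f\n", i, run.sum, run.sumSq);
  }
  return 0;
}

The scale kernel then turns each group's (sum, sumSq) pair into mean and variance via var = E[v^2] - E[v]^2, exactly as in the CPU reference given after the file list.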
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h (new file, mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
 public:
  PrelnGroupnormActPluginDynamic(const float *scale,
                                 const int scale_num,
                                 const float *bias,
                                 const int bias_num,
                                 float eps,
                                 int groups,
                                 bool with_fp16,
                                 std::shared_ptr<void> scale_gpu = nullptr,
                                 std::shared_ptr<void> bias_gpu = nullptr)
      : scale_gpu_(scale_gpu),
        bias_gpu_(bias_gpu),
        groups_(groups),
        eps_(eps),
        with_fp16_(with_fp16) {
    scale_.resize(scale_num);
    bias_.resize(bias_num);
    std::copy(scale, scale + scale_num, scale_.data());
    std::copy(bias, bias + bias_num, bias_.data());
    if (scale_gpu_ == nullptr) {
      void *p;
      cudaMalloc(reinterpret_cast<void **>(&p), scale_num * sizeof(float));
      scale_gpu_.reset(p, [](void *ptr) { cudaFree(ptr); });
      cudaMemcpy(
          p, scale_.data(), scale_num * sizeof(float), cudaMemcpyHostToDevice);
    }
    if (bias_gpu_ == nullptr) {
      void *p;
      cudaMalloc(reinterpret_cast<void **>(&p), bias_num * sizeof(float));
      bias_gpu_.reset(p, [](void *ptr) { cudaFree(ptr); });
      cudaMemcpy(
          p, bias_.data(), bias_num * sizeof(float), cudaMemcpyHostToDevice);
    }
  }

  PrelnGroupnormActPluginDynamic(void const *serialData, size_t serialLength) {
    DeserializeValue(&serialData, &serialLength, &scale_);
    DeserializeValue(&serialData, &serialLength, &bias_);
    DeserializeValue(&serialData, &serialLength, &eps_);
    DeserializeValue(&serialData, &serialLength, &groups_);
    DeserializeValue(&serialData, &serialLength, &with_fp16_);
    {
      void *p;
      cudaMalloc(reinterpret_cast<void **>(&p), scale_.size() * sizeof(float));
      scale_gpu_.reset(p, [](void *ptr) { cudaFree(ptr); });
      cudaMemcpy(p,
                 scale_.data(),
                 scale_.size() * sizeof(float),
                 cudaMemcpyHostToDevice);
    }
    {
      void *p;
      cudaMalloc(reinterpret_cast<void **>(&p), bias_.size() * sizeof(float));
      bias_gpu_.reset(p, [](void *ptr) { cudaFree(ptr); });
      cudaMemcpy(p,
                 bias_.data(),
                 bias_.size() * sizeof(float),
                 cudaMemcpyHostToDevice);
    }
  }

  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override {
    auto *ptr = new PrelnGroupnormActPluginDynamic(scale_.data(),
                                                   scale_.size(),
                                                   bias_.data(),
                                                   bias_.size(),
                                                   eps_,
                                                   groups_,
                                                   with_fp16_,
                                                   scale_gpu_,
                                                   bias_gpu_);
    return ptr;
  }

  const char *getPluginType() const TRT_NOEXCEPT override {
    return "preln_groupnorm_act_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
  int initialize() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override {
    return SerializedSize(scale_) + SerializedSize(bias_) +
           SerializedSize(eps_) + SerializedSize(groups_) +
           SerializedSize(with_fp16_);
  }

  void serialize(void *buffer) const TRT_NOEXCEPT override {
    SerializeValue(&buffer, scale_);
    SerializeValue(&buffer, bias_);
    SerializeValue(&buffer, eps_);
    SerializeValue(&buffer, groups_);
    SerializeValue(&buffer, with_fp16_);
  }

  nvinfer1::DimsExprs getOutputDimensions(
      int output_index,
      const nvinfer1::DimsExprs *inputs,
      int nb_inputs,
      nvinfer1::IExprBuilder &expr_builder)  // NOLINT
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc *inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override {
    // sizeof(float2) * maxBatchSize * maxNumberOfGroup. float2
    // contains two buffers for sum and squared sum;
    ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_;
  }

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return ws_;
  }

  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs,
              void *const *outputs,
              void *workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }
  void terminate() TRT_NOEXCEPT override{};

 private:
  size_t ws_;
  std::vector<float> scale_;
  std::vector<float> bias_;
  std::shared_ptr<void> scale_gpu_;
  std::shared_ptr<void> bias_gpu_;
  GroupNormNHWCParams params_;
  int groups_;
  float eps_;
  bool with_fp16_;
};

class PrelnGroupnormActPluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char *getPluginName() const TRT_NOEXCEPT override {
    return "preln_groupnorm_act_plugin_dynamic";
  }

  const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                         const void *serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new PrelnGroupnormActPluginDynamic(serial_data, serial_length);
  }
};

REGISTER_TRT_PLUGIN_V2(PrelnGroupnormActPluginDynamicCreator);

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
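configurePlugin above sizes the reduction workspace at two float accumulators (sum and sum of squares) per batch item per group, which enqueue clears with cudaMemsetAsync before launching the kernels. The arithmetic, with made-up example values:

#include <cstddef>
#include <cstdio>

int main() {
  const int max_batch = 8;  // in[0].max.d[0], example value
  const int groups = 32;    // groups_, example value
  const std::size_t ws = sizeof(float) * 2 * max_batch * groups;
  std::printf("workspace bytes = %zu\n", ws);  // 4 * 2 * 8 * 32 = 2048
  return 0;
}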
paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu
0 → 100644
浏览文件 @
591be3bd
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES.
All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h"
#include <cub/cub.cuh>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
plugin
{
nvinfer1
::
DimsExprs
SkipGroupnormActPluginDynamic
::
getOutputDimensions
(
int
output_index
,
const
nvinfer1
::
DimsExprs
*
inputDims
,
int
nb_inputs
,
nvinfer1
::
IExprBuilder
&
expr_builder
)
TRT_NOEXCEPT
{
return
inputDims
[
0
];
}
bool
SkipGroupnormActPluginDynamic
::
supportsFormatCombination
(
int
pos
,
const
nvinfer1
::
PluginTensorDesc
*
in_out
,
int
nb_inputs
,
int
nb_outputs
)
TRT_NOEXCEPT
{
PADDLE_ENFORCE_NOT_NULL
(
in_out
,
platform
::
errors
::
InvalidArgument
(
"The input of SkipGroupnormAct plugin shoule not be nullptr."
));
PADDLE_ENFORCE_LT
(
pos
,
nb_inputs
+
nb_outputs
,
platform
::
errors
::
InvalidArgument
(
"The pos(%d) should be less than the "
"num(%d) of the input and the output."
,
pos
,
nb_inputs
+
nb_outputs
));
const
nvinfer1
::
PluginTensorDesc
&
in
=
in_out
[
pos
];
if
(
pos
==
0
)
{
if
(
with_fp16_
)
{
return
((
in
.
type
==
nvinfer1
::
DataType
::
kHALF
)
&&
(
in
.
format
==
nvinfer1
::
PluginFormat
::
kHWC8
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"SkipGroupnormAct TRT Plugin is fp16 only so far"
));
return
(
in
.
type
==
nvinfer1
::
DataType
::
kFLOAT
)
&&
(
in
.
format
==
nvinfer1
::
TensorFormat
::
kLINEAR
);
}
}
const
nvinfer1
::
PluginTensorDesc
&
prev
=
in_out
[
pos
-
1
];
// output
return
in
.
type
==
prev
.
type
&&
in
.
format
==
prev
.
format
;
}
nvinfer1::DataType SkipGroupnormActPluginDynamic::getOutputDataType(
    int index,
    const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index,
                    0,
                    platform::errors::InvalidArgument(
                        "The SkipGroupnormAct Plugin only has one input, so "
                        "the index value should be 0, but get %d.",
                        index));
  PADDLE_ENFORCE_EQ(
      (input_types[0] == nvinfer1::DataType::kFLOAT ||
       input_types[0] == nvinfer1::DataType::kHALF),
      true,
      platform::errors::InvalidArgument(
          "The input type should be half or float"));
  return input_types[0];
}

int SkipGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }

static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) {
  int32_t maxDivisor = -1;
  for (int32_t i = 1; i <= std::sqrt(n); i++) {
    if (n % i == 0) {
      int32_t divisor1 = n / i;
      int32_t divisor2 = i;
      if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) {
        maxDivisor = divisor1;
      }
      if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) {
        maxDivisor = divisor2;
      }
    }
  }
  return maxDivisor;
}
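
To make the block-sizing arithmetic concrete, here is a minimal host-side check of the two helpers above (the 64x64 feature map is an illustrative shape, not one taken from the source):

#include <cassert>

// For hw = 64 * 64 = 4096 and at most 1024 blocks per HW plane, the largest
// divisor of 4096 strictly below 1024 is 512; each block then covers
// divUp(4096, 512) = 8 activations. These are the values enqueue() later
// feeds into params_.hwPerBlock.
void checkBlockSizing() {
  assert(findMaxDivisor(4096, 1024) == 512);
  assert(divUp(4096, 512) == 8);
}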
static inline __device__ __host__ float sigmoid(float x) {
  return 1.F / (1.F + expf(-x));
}
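
The helper above is the logistic sigmoid; combined with the multiplication applied later in the scale kernel, it yields the Swish/SiLU activation that the fused silu op maps to:

$$\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad \mathrm{silu}(x) = x \cdot \sigma(x)$$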
struct GroupSums {
  // Is it the 1st element of the group?
  int32_t flag;
  // The sum.
  float sum;
  // The sum of squares.
  float sumSq;
};

struct GroupSumsOp {
  inline __device__ GroupSums operator()(GroupSums const& a,
                                         GroupSums const& b) {
    GroupSums dst;
    dst.sum = b.flag ? b.sum : (a.sum + b.sum);
    dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq);
    dst.flag = a.flag + b.flag;
    return dst;
  }
};
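
GroupSumsOp is what turns cub::BlockScan into a segmented scan: whenever the right-hand operand carries flag = 1 (the first channel pair of a group), the running sum and sum of squares restart, so each group is reduced independently within a single scan. A plain C++ analog of the operator (host-side, since the version above is __device__-only; the four sample values are hypothetical):

#include <cstdio>

struct HostGroupSums { int flag; float sum; float sumSq; };

// Same combine rule as GroupSumsOp::operator() above.
static HostGroupSums combine(HostGroupSums a, HostGroupSums b) {
  return {a.flag + b.flag,
          b.flag ? b.sum : a.sum + b.sum,
          b.flag ? b.sumSq : a.sumSq + b.sumSq};
}

void demoSegmentedScan() {
  // Two groups of two values each: {1, 2} and {3, 4}; flag marks group starts.
  HostGroupSums in[4] = {{1, 1, 1}, {0, 2, 4}, {1, 3, 9}, {0, 4, 16}};
  HostGroupSums acc = in[0];
  for (int i = 1; i < 4; ++i) {
    acc = combine(acc, in[i]);
    // Prints: i=1 sum=3 sumSq=5, i=2 sum=3 sumSq=9, i=3 sum=7 sumSq=25.
    // The last element of each segment holds that group's totals.
    printf("i=%d sum=%.0f sumSq=%.0f\n", i, acc.sum, acc.sumSq);
  }
}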
template <int32_t tTHREADS_PER_BLOCK>
__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
  // The object in charge of doing the sums for the different blocks.
  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;

  // Allocate shared memory for BlockScan.
  __shared__ typename BlockScan::TempStorage tempStorage;
  // Allocate shared memory for the groups. We could reduce the amount of
  // shared memory reserved.
  __shared__ float2 smem[tTHREADS_PER_BLOCK];

  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;

  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

  // The sums.
  float sum = 0.F;
  float sumSq = 0.F;

  // Iterate over the activations to compute the sums.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The offset.
    int64_t offset = static_cast<int64_t>(ni) * params.hwc +
                     static_cast<int64_t>(hwi) * params.c + ci;
    // Fetch two channels per thread.
    __half2 h2(0, 0);
    if (ci < params.c) {
      // srcY is broadcast along HW (W = 1, H = 1).
      int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
      __half2 y = *reinterpret_cast<__half2 const*>(&params.srcY[offsetY]);
      h2 = *reinterpret_cast<__half2 const*>(&params.srcX[offset]);
      h2 = __hadd2(h2, y);  // elementwise_add
      *reinterpret_cast<__half2*>(&params.dst[offset]) = h2;
    }
    // Extract the two half values.
    float2 f2 = __half22float2(h2);
    // Update the sum.
    sum += f2.x + f2.y;
    // Update the sum of squares.
    sumSq += f2.x * f2.x + f2.y * f2.y;
  }

  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;

  // The data for the summations.
  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};

  // Do the segmented scan.
  GroupSums out;
  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());

  // Store the results for the groups in shared memory (to produce coalesced
  // stores later).
  if (cj == params.cPerGroup - 2 /* 2 channels per thread */) {
    smem[gi] = make_float2(out.sum, out.sumSq);
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The global group index.
  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;

  // Threads that have nothing left to do, exit.
  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
    return;
  }

  // The first threads (those storing to global memory) load the values.
  float2 sums = smem[threadIdx.x];

  // Store to global memory.
  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
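
Since several blocks can each own a disjoint HW slice of the same (instance, group) pair, the kernel above publishes its partial results with atomicAdd into a float workspace of 2 * N * groups entries: for instance ni, row 2*ni holds the group sums and row 2*ni+1 the sums of squares. A sketch of the index math (helper names are mine, not from the source):

// redBuffer layout: [N][2][groups] floats, zeroed before the sum kernel runs.
inline int sumIndex(int ni, int gj, int groups) {
  return (2 * ni + 0) * groups + gj;  // per-(instance, group) sum
}
inline int sumSqIndex(int ni, int gj, int groups) {
  return (2 * ni + 1) * groups + gj;  // per-(instance, group) sum of squares
}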
void skipGroupNormNHWCSum(GroupNormNHWCParams const& params,
                          cudaStream_t stream) {
  // Make sure the values are as we expect.
  PADDLE_ENFORCE_EQ(
      params.c % params.cPerBlock,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong "
          "parameters: params.c %% params.cPerBlock should be 0, but get %d.",
          params.c % params.cPerBlock));
  PADDLE_ENFORCE_EQ(
      params.hw % params.hwPerBlock,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong "
          "parameters: params.hw %% params.hwPerBlock should be 0, "
          "but get %d.",
          params.hw % params.hwPerBlock));
  // Make sure a group does not span multiple blocks.
  PADDLE_ENFORCE_EQ(
      params.cPerBlock % params.cPerGroup,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong "
          "parameters: params.cPerBlock %% params.cPerGroup should be 0, "
          "but get %d.",
          params.cPerBlock % params.cPerGroup));
  dim3 grid;

  // The number of blocks to compute all the channels.
  grid.x = params.c / params.cPerBlock;
  // The number of blocks to compute all the activations in a given instance.
  grid.y = divUp(params.hw, params.hwPerBlock);
  // The number of instances.
  grid.z = params.n;

  switch (params.cPerBlock) {
    case 320:
      skipGroupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
      break;
    case 480:
      skipGroupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
      break;
    case 256:
      skipGroupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
      break;
    case 128:
      skipGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
      break;
    default:
      PADDLE_THROW(platform::errors::Fatal(
          "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
          "encountered an unsupported cPerBlock value"));
  }
}
template <int32_t tTHREADS_PER_BLOCK>
__global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
  // The instance in the batch.
  int32_t ni = blockIdx.z;
  // The channel loaded by that thread (2 channels per thread for F16x2).
  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
  // The group that thread works on and the channel in the group (modulus).
  int32_t gi = ci / params.cPerGroup;

  // Load the sum and sum of squares for the group.
  float sum = 0.F, sumSq = 0.F;
  if (gi < params.groups) {
    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
  }

  // Load gamma/beta.
  float2 gammaF2, betaF2;
  if (ci < params.c) {
    gammaF2 = *reinterpret_cast<float2 const*>(
        reinterpret_cast<float const*>(params.gamma) + ci);
    betaF2 = *reinterpret_cast<float2 const*>(
        reinterpret_cast<float const*>(params.beta) + ci);
  }

  // Compute the mean.
  float mean = sum * params.invHWC;
  // Compute the variance.
  float var = sumSq * params.invHWC - (mean * mean);
  // Compute the inverse of the stddev.
  float invStdDev = rsqrtf(var + params.eps);

  // The first activation loaded by that block.
  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
  // The last activation loaded by that block.
  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

  // Iterate over the activations to normalize them and store the results.
  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
    // The src/dst offset.
    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;

    // Fetch two channels per thread.
    __half2 h2(0, 0);
    if (ci < params.c) {
      h2 = *reinterpret_cast<__half2 const*>(&params.dst[offset]);
    }

    // Extract the two half values.
    float2 f2 = __half22float2(h2);

    // Normalize the channels.
    f2.x = (f2.x - mean) * invStdDev;
    f2.y = (f2.y - mean) * invStdDev;

    // Scale by gamma and add beta.
    f2.x = gammaF2.x * f2.x + betaF2.x;
    f2.y = gammaF2.y * f2.y + betaF2.y;

    // Apply Swish if needed.
    if (params.withSwish) {
      f2.x = f2.x * sigmoid(f2.x);
      f2.y = f2.y * sigmoid(f2.y);
    }

    // Store the scaled values.
    if (ci < params.c) {
      *reinterpret_cast<__half2*>(&params.dst[offset]) = __float22half2_rn(f2);
    }
  }
}
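
Restating what the two kernels compute (this is standard group normalization over the fused skip sum, not a new formula): for each instance and group g, with S_g the set of H*W*(C/groups) activations in the group and x already equal to srcX + srcY,

$$\mu_g = \frac{1}{|S_g|} \sum_{i \in S_g} x_i, \qquad
\sigma_g^2 = \frac{1}{|S_g|} \sum_{i \in S_g} x_i^2 - \mu_g^2, \qquad
y_i = \mathrm{silu}\!\left( \gamma_c \, \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}} + \beta_c \right)$$

where invHWC in the code is 1/|S_g| and withSwish selects the final silu.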
void skipGroupNormNHWCScale(GroupNormNHWCParams const& params,
                            cudaStream_t stream) {
  // Make sure the dimensions are aligned with what we expect.
  PADDLE_ENFORCE_EQ(
      params.c % params.cPerBlock,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCScale of SkipGroupnormAct Plugin got wrong "
          "parameters: params.c %% params.cPerBlock should be 0, but get %d.",
          params.c % params.cPerBlock));
  // Make sure a group does not span multiple blocks.
  PADDLE_ENFORCE_EQ(
      params.cPerBlock % params.cPerGroup,
      0,
      platform::errors::InvalidArgument(
          "The groupNormNHWCScale of SkipGroupnormAct Plugin got wrong "
          "parameters: params.cPerBlock %% params.cPerGroup should be 0, "
          "but get %d.",
          params.cPerBlock % params.cPerGroup));
  dim3 grid;

  // The number of blocks to compute all the channels.
  grid.x = params.c / params.cPerBlock;
  // The number of blocks to compute all the activations in a given instance.
  grid.y = divUp(params.hw, params.hwPerBlock);
  // The number of instances.
  grid.z = params.n;

  switch (params.cPerBlock) {
    case 320:
      skipGroupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
      break;
    case 480:
      skipGroupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
      break;
    case 256:
      skipGroupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
      break;
    case 128:
      skipGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
      break;
    default:
      PADDLE_THROW(platform::errors::Fatal(
          "The function groupNormNHWCScale of SkipGroupnormAct TRT Plugin "
          "encountered an unsupported cPerBlock value"));
  }
}
int SkipGroupnormActPluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc,
    const void* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) TRT_NOEXCEPT {
  auto input_type = input_desc[0].type;
  if (input_type == nvinfer1::DataType::kFLOAT) {
    VLOG(1) << "TRT Plugin DataType selected. SkipGroupnormAct-->fp32";
    PADDLE_THROW(platform::errors::Fatal(
        "The SkipGroupnormAct TRT Plugin only supports fp16 input"));
  } else if (input_type == nvinfer1::DataType::kHALF) {
    VLOG(1) << "TRT Plugin DataType selected. SkipGroupnormAct-->fp16";
    int32_t cPerBlock = 320;
    int32_t maxBlocksPerHW = 1024;

    switch (input_desc[0].dims.d[1]) {
      case 960:
      case 1920:
        cPerBlock = 480;
        break;
      case 512:
      case 256:
        cPerBlock = 256;
        break;
      case 128:
        cPerBlock = 128;
        break;
      default:
        cPerBlock = 320;
    }
    params_.withSwish = true;
    params_.dst = static_cast<half*>(outputs[0]);
    params_.srcX = static_cast<half const*>(inputs[0]);
    params_.srcY = static_cast<half const*>(inputs[1]);
    params_.gamma = scale_gpu_.get();
    params_.beta = bias_gpu_.get();
    params_.redBuffer = static_cast<float*>(workspace);
    params_.n = input_desc[0].dims.d[0];
    params_.h = input_desc[0].dims.d[2];
    params_.w = input_desc[0].dims.d[3];
    params_.c = input_desc[0].dims.d[1];
    params_.groups = groups_;
    params_.hw = params_.h * params_.w;
    const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
    params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
    params_.cPerBlock = cPerBlock;
    params_.cPerGroup = params_.c / params_.groups;
    params_.hwc = params_.hw * params_.c;
    params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
    params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
    params_.eps = eps_;

    cudaMemsetAsync(params_.redBuffer, 0, ws_, stream);
    skipGroupNormNHWCSum(params_, stream);
    skipGroupNormNHWCScale(params_, stream);
  } else {
    // input not fp16
    PADDLE_THROW(platform::errors::Fatal(
        "The SkipGroupnormAct TRT Plugin only supports fp16 input"));
  }
  return cudaGetLastError() != cudaSuccess;
}
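
As a concrete trace of the derivation above (shapes are illustrative only): for an fp16 input of [2, 320, 64, 64] with groups_ = 32, enqueue picks cPerBlock = 320 (the default case), so cPerGroup = 10, hw = 4096, blocksPerHW = findMaxDivisor(4096, 1024) = 512, hwPerBlock = 8, groupsPerBlock = 32 and invHWC = 1 / (4096 * 10); both kernels then launch with grid = (1, 512, 2) and 160 threads per block, and the memset clears the reduction workspace of ws_ = 4 B * 2 * maxBatch * 32 bytes (512 B if the profile's max batch is 2).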
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h
0 → 100644
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
 public:
  SkipGroupnormActPluginDynamic(const float* scale,
                                const int scale_num,
                                const float* bias,
                                const int bias_num,
                                float eps,
                                int groups,
                                bool with_fp16,
                                std::shared_ptr<void> scale_gpu = nullptr,
                                std::shared_ptr<void> bias_gpu = nullptr)
      : scale_gpu_(scale_gpu),
        bias_gpu_(bias_gpu),
        groups_(groups),
        eps_(eps),
        with_fp16_(with_fp16) {
    scale_.resize(scale_num);
    bias_.resize(bias_num);
    std::copy(scale, scale + scale_num, scale_.data());
    std::copy(bias, bias + bias_num, bias_.data());
    if (scale_gpu_ == nullptr) {
      void* p;
      cudaMalloc(reinterpret_cast<void**>(&p), scale_num * sizeof(float));
      scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
      cudaMemcpy(
          p, scale_.data(), scale_num * sizeof(float), cudaMemcpyHostToDevice);
    }
    if (bias_gpu_ == nullptr) {
      void* p;
      cudaMalloc(reinterpret_cast<void**>(&p), bias_num * sizeof(float));
      bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
      cudaMemcpy(
          p, bias_.data(), bias_num * sizeof(float), cudaMemcpyHostToDevice);
    }
  }
  SkipGroupnormActPluginDynamic(void const* serialData, size_t serialLength) {
    DeserializeValue(&serialData, &serialLength, &scale_);
    DeserializeValue(&serialData, &serialLength, &bias_);
    DeserializeValue(&serialData, &serialLength, &eps_);
    DeserializeValue(&serialData, &serialLength, &groups_);
    DeserializeValue(&serialData, &serialLength, &with_fp16_);
    {
      void* p;
      cudaMalloc(reinterpret_cast<void**>(&p), scale_.size() * sizeof(float));
      scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
      cudaMemcpy(p,
                 scale_.data(),
                 scale_.size() * sizeof(float),
                 cudaMemcpyHostToDevice);
    }
    {
      void* p;
      cudaMalloc(reinterpret_cast<void**>(&p), bias_.size() * sizeof(float));
      bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
      cudaMemcpy(p,
                 bias_.data(),
                 bias_.size() * sizeof(float),
                 cudaMemcpyHostToDevice);
    }
  }
  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    auto* ptr = new SkipGroupnormActPluginDynamic(scale_.data(),
                                                  scale_.size(),
                                                  bias_.data(),
                                                  bias_.size(),
                                                  eps_,
                                                  groups_,
                                                  with_fp16_,
                                                  scale_gpu_,
                                                  bias_gpu_);
    return ptr;
  }
  const char* getPluginType() const TRT_NOEXCEPT override {
    return "skip_groupnorm_act_plugin_dynamic";
  }
  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
  int initialize() TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override {
    return SerializedSize(scale_) + SerializedSize(bias_) +
           SerializedSize(eps_) + SerializedSize(groups_) +
           SerializedSize(with_fp16_);
  }

  void serialize(void* buffer) const TRT_NOEXCEPT override {
    SerializeValue(&buffer, scale_);
    SerializeValue(&buffer, bias_);
    SerializeValue(&buffer, eps_);
    SerializeValue(&buffer, groups_);
    SerializeValue(&buffer, with_fp16_);
  }
  nvinfer1::DimsExprs getOutputDimensions(
      int output_index,
      const nvinfer1::DimsExprs* inputs,
      int nb_inputs,
      nvinfer1::IExprBuilder& expr_builder)  // NOLINT
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) TRT_NOEXCEPT override {
    // sizeof(float2) * maxBatchSize * maxNumberOfGroup; the float2
    // contains two buffers, for the sum and the squared sum.
    ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_;
  }
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return ws_;
  }

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs,
              void* const* outputs,
              void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputTypes,
                                       int nbInputs) const
      TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }
  void terminate() TRT_NOEXCEPT override {}
 private:
  size_t ws_;
  std::vector<float> scale_;
  std::vector<float> bias_;
  std::shared_ptr<void> scale_gpu_;
  std::shared_ptr<void> bias_gpu_;
  GroupNormNHWCParams params_;
  int groups_;
  float eps_;
  bool with_fp16_;
};
class SkipGroupnormActPluginDynamicCreator : public TensorRTPluginCreator {
 public:
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "skip_groupnorm_act_plugin_dynamic";
  }

  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length)
      TRT_NOEXCEPT override {
    return new SkipGroupnormActPluginDynamic(serial_data, serial_length);
  }
};

REGISTER_TRT_PLUGIN_V2(SkipGroupnormActPluginDynamicCreator);
}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
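
A note on the ownership pattern in the constructors above (my reading of the design, not something the commit states): the device copies of scale/bias live in std::shared_ptr<void> with a cudaFree deleter, so clone() can hand the same weights to every plugin instance and the buffer is released exactly once, by the last surviving owner. A minimal sketch of the idiom:

#include <cuda_runtime.h>
#include <memory>

// Device buffer shared across plugin clones; freed once by the last owner.
std::shared_ptr<void> makeDeviceBuffer(const float* host, size_t n) {
  void* p = nullptr;
  cudaMalloc(&p, n * sizeof(float));
  cudaMemcpy(p, host, n * sizeof(float), cudaMemcpyHostToDevice);
  return std::shared_ptr<void>(p, [](void* ptr) { cudaFree(ptr); });
}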
paddle/fluid/operators/compat/group_norm.pbtxt
0 → 100644
type: "group_norm"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "Bias"
}
outputs {
name: "Y"
}
outputs {
name: "Mean"
}
outputs {
name: "Variance"
}
attrs {
name: "epsilon"
type: FLOAT
}
attrs {
name: "groups"
type: INT
}
attrs {
name: "data_layout"
type: STRING
}
}
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
paddle/fluid/operators/compat/silu.pbtxt
0 → 100644
type: "silu"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
}
extra {
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -35,6 +35,10 @@ endif()
if(WIN32)
  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_fused_token_prune")
  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_preln_groupnorm_act_fuse_pass")
  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_element_groupnorm_act_fuse_pass")
  list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
  list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
endif()
@@ -217,6 +221,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
    set_tests_properties(test_map_matmul_v2_to_mul_pass PROPERTIES TIMEOUT 120)
    set_tests_properties(test_map_matmul_to_mul_pass PROPERTIES TIMEOUT 120)
    set_tests_properties(test_element_groupnorm_act_fuse_pass PROPERTIES TIMEOUT 120)
    set_tests_properties(test_preln_groupnorm_act_fuse_pass PROPERTIES TIMEOUT 120)
  endif()
endif()
python/paddle/fluid/tests/unittests/ir/inference/test_element_groupnorm_act_fuse_pass.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig

import paddle.inference as paddle_infer
class TestElementGNActPass(PassAutoScanTest):
    #
    #      |          |                      |          |
    #  other_op1  other_op2             other_op1  other_op2
    #      |          |        fuse          \        /
    #     elementwise_add       ->        skip_groupnorm_act
    #          |                                 |
    #      groupnorm
    #          |
    #        silu
    def sample_predictor_configs(self, program_config):
        # trt dynamic_shape
        config = self.create_trt_inference_config()
        config.enable_tensorrt_engine(
            max_batch_size=1,
            workspace_size=102400,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Half,
            use_static=False,
            use_calib_mode=False,
        )
        config.set_trt_dynamic_shape_info(
            {
                "input_data_x": [1, 160, 1, 1],
                "input_data_y": [1, 160, 1, 1],
            },
            {
                "input_data_x": [4, 1280, 64, 64],
                "input_data_y": [4, 1280, 1, 1],
            },
            {
                "input_data_x": [1, 320, 1, 1],
                "input_data_y": [1, 320, 1, 1],
            },
        )
        yield config, ['skip_groupnorm_act'], (3e-3, 1e-3)
    def sample_program_config(self, draw):
        axis = draw(st.sampled_from([0, -1]))
        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
        batch_size = draw(st.integers(min_value=1, max_value=4))
        groups = draw(st.sampled_from([4, 8, 16, 32]))
        hw = draw(st.sampled_from([1, 8, 16, 32]))
        channel = draw(st.sampled_from([320, 1280]))

        def generate_input_x(attrs):
            return np.random.random(
                [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
            ).astype(np.float32)

        def generate_input_y(attrs):
            return np.random.random(
                [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
            ).astype(np.float32)

        def generate_weight(attrs):
            return np.random.random(attrs[1]['input_dim_x'][0]).astype(
                np.float32
            )
        attrs = [
            {
                'axis': axis,
                'epsilon': epsilon,
                'groups': groups,
            },
            {
                'batch_size': batch_size,
                'input_dim_x': [channel, hw, hw],
                'input_dim_y': [channel, 1, 1],
            },
        ]

        elementwise_add_op = OpConfig(
            type="elementwise_add",
            inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
            outputs={"Out": ["ele_out"]},
            attrs={"axis": attrs[0]['axis']},
        )

        group_norm_op = OpConfig(
            type="group_norm",
            inputs={
                "X": ["ele_out"],
                "Bias": ["group_norm_bias"],
                "Scale": ["group_norm_scale"],
            },
            outputs={
                "Y": ["group_norm_output1"],
                "Mean": ["group_norm_output2"],
                "Variance": ["group_norm_output3"],
            },
            attrs={
                "data_layout": "NCHW",
                "groups": attrs[0]["groups"],
                "epsilon": attrs[0]["epsilon"],
            },
        )

        silu_op = OpConfig(
            type="silu",
            inputs={
                "X": ["group_norm_output1"],
            },
            outputs={
                "Out": ["silu_output"],
            },
        )

        program_config = ProgramConfig(
            ops=[
                elementwise_add_op,
                group_norm_op,
                silu_op,
            ],
            weights={
                "group_norm_bias": TensorConfig(
                    data_gen=partial(generate_weight, attrs)
                ),
                "group_norm_scale": TensorConfig(
                    data_gen=partial(generate_weight, attrs)
                ),
            },
            inputs={
                "input_data_x": TensorConfig(
                    data_gen=partial(generate_input_x, attrs)
                ),
                "input_data_y": TensorConfig(
                    data_gen=partial(generate_input_y, attrs)
                ),
            },
            outputs=["silu_output"],
        )

        return program_config
    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=50,
            passes=["elementwise_groupnorm_act_pass"],
            max_duration=250,
            min_success_num=50,
        )


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig

import paddle.inference as paddle_infer
class TestElementGNActPass(PassAutoScanTest):
    #
    #      |          |                      |          |
    #  other_op1  other_op2             other_op1  other_op2
    #      |          |        fuse          \        /
    #     elementwise_add       ->        preln_groupnorm_act
    #      |       |                         |       |
    #  other_op3  groupnorm              other_op3
    #                |
    #               silu
    def sample_predictor_configs(self, program_config):
        # trt dynamic_shape
        config = self.create_trt_inference_config()
        config.enable_tensorrt_engine(
            max_batch_size=1,
            workspace_size=102400,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Half,
            use_static=False,
            use_calib_mode=False,
        )
        config.set_trt_dynamic_shape_info(
            {
                "input_data_x": [1, 160, 1, 1],
                "input_data_y": [1, 160, 1, 1],
            },
            {
                "input_data_x": [4, 1280, 64, 64],
                "input_data_y": [4, 1280, 64, 64],
            },
            {
                "input_data_x": [1, 320, 32, 32],
                "input_data_y": [1, 320, 32, 32],
            },
        )
        yield config, ['preln_groupnorm_act'], (3e-3, 1e-3)
    def sample_program_config(self, draw):
        axis = draw(st.sampled_from([0, -1]))
        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
        batch_size = draw(st.integers(min_value=1, max_value=4))
        groups = draw(st.sampled_from([4, 8, 16, 32]))
        hw = draw(st.sampled_from([1, 8, 16, 32]))
        channel = draw(st.sampled_from([320, 1280]))

        def generate_input_x(attrs):
            return np.random.random(
                [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
            ).astype(np.float32)

        def generate_input_y(attrs):
            return np.random.random(
                [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
            ).astype(np.float32)

        def generate_weight(attrs):
            return np.random.random(attrs[1]['input_dim_x'][0]).astype(
                np.float32
            )
        attrs = [
            {
                'axis': axis,
                'epsilon': epsilon,
                'groups': groups,
            },
            {
                'batch_size': batch_size,
                'input_dim_x': [channel, hw, hw],
                'input_dim_y': [channel, hw, hw],
            },
        ]

        elementwise_add_op = OpConfig(
            type="elementwise_add",
            inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
            outputs={"Out": ["ele_out"]},
            attrs={"axis": attrs[0]['axis']},
        )

        group_norm_op = OpConfig(
            type="group_norm",
            inputs={
                "X": ["ele_out"],
                "Bias": ["group_norm_bias"],
                "Scale": ["group_norm_scale"],
            },
            outputs={
                "Y": ["group_norm_output1"],
                "Mean": ["group_norm_output2"],
                "Variance": ["group_norm_output3"],
            },
            attrs={
                "data_layout": "NCHW",
                "groups": attrs[0]["groups"],
                "epsilon": attrs[0]["epsilon"],
            },
        )

        silu_op = OpConfig(
            type="silu",
            inputs={
                "X": ["group_norm_output1"],
            },
            outputs={
                "Out": ["silu_output"],
            },
        )

        program_config = ProgramConfig(
            ops=[
                elementwise_add_op,
                group_norm_op,
                silu_op,
            ],
            weights={
                "group_norm_bias": TensorConfig(
                    data_gen=partial(generate_weight, attrs)
                ),
                "group_norm_scale": TensorConfig(
                    data_gen=partial(generate_weight, attrs)
                ),
            },
            inputs={
                "input_data_x": TensorConfig(
                    data_gen=partial(generate_input_x, attrs)
                ),
                "input_data_y": TensorConfig(
                    data_gen=partial(generate_input_y, attrs)
                ),
            },
            outputs=["ele_out", "silu_output"],
        )

        return program_config
    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=50,
            passes=["preln_elementwise_groupnorm_act_pass"],
            max_duration=250,
            min_success_num=50,
        )


if __name__ == "__main__":
    unittest.main()