Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
adcb0039
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
adcb0039
编写于
1月 12, 2023
作者:
W
wenbin
提交者:
GitHub
1月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
more preln_gn patterns (#49728)
* compile fix * fix compile * compile fix * add more preln
上级
a015f815
变更
6
隐藏空白更改
内联
并排
Showing 6 changed files with 189 additions and 24 deletions
+189
-24
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc
...luid/framework/ir/preln_elementwise_groupnorm_act_pass.cc
+41
-20
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h
...fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h
+2
-2
paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc
...luid/inference/tensorrt/convert/preln_groupnorm_act_op.cc
+2
-0
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
...nference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
+1
-1
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h
...inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h
+7
-1
python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
...ttests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
+136
-0
未找到文件。
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc
浏览文件 @
adcb0039
...
...
@@ -35,7 +35,7 @@ struct PrelnGroupNormAct : public PatternBase {
PrelnGroupNormAct
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"preln_groupnorm_act"
)
{}
void
operator
()(
PDNode
*
x
,
PDNode
*
y
);
void
operator
()(
PDNode
*
x
,
PDNode
*
y
,
bool
with_act
);
// declare operator node's name
PATTERN_DECL_NODE
(
elementwise
);
PATTERN_DECL_NODE
(
group_norm
);
...
...
@@ -49,7 +49,7 @@ struct PrelnGroupNormAct : public PatternBase {
PATTERN_DECL_NODE
(
act_out
);
};
void
PrelnGroupNormAct
::
operator
()(
PDNode
*
x
,
PDNode
*
y
)
{
void
PrelnGroupNormAct
::
operator
()(
PDNode
*
x
,
PDNode
*
y
,
bool
with_act
)
{
auto
*
elementwise
=
pattern
->
NewNode
(
elementwise_repr
())
->
assert_is_op
(
"elementwise_add"
);
...
...
@@ -74,26 +74,28 @@ void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
auto
*
group_norm_out_var
=
pattern
->
NewNode
(
group_norm_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"group_norm"
,
"Y"
)
->
assert_is_op_input
(
"silu"
,
"X"
);
->
assert_is_op_output
(
"group_norm"
,
"Y"
);
// Add links for group_norm op.
group_norm
->
LinksFrom
(
{
elementwise_out_var
,
group_norm_bias_var
,
group_norm_scale_var
})
.
LinksTo
({
group_norm_out_var
});
auto
*
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
"silu"
);
auto
*
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"silu"
,
"Out"
);
act
->
LinksFrom
({
group_norm_out_var
}).
LinksTo
({
act_out
});
if
(
with_act
)
{
group_norm_out_var
->
assert_is_op_input
(
"silu"
,
"X"
);
auto
*
act
=
pattern
->
NewNode
(
act_repr
())
->
assert_is_op
(
"silu"
);
auto
*
act_out
=
pattern
->
NewNode
(
act_out_repr
())
->
AsOutput
()
->
assert_is_op_output
(
"silu"
,
"Out"
);
act
->
LinksFrom
({
group_norm_out_var
}).
LinksTo
({
act_out
});
}
}
}
// namespace patterns
int
PrelnGroupNormActFusePass
::
ApplyGNSiluPattern
(
ir
::
Graph
*
graph
)
const
{
int
PrelnGroupNormActFusePass
::
ApplyAddGNPattern
(
ir
::
Graph
*
graph
,
bool
with_act
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
"preln_groupnorm_silu_fuse"
,
graph
);
...
...
@@ -118,7 +120,7 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
patterns
::
PrelnGroupNormAct
fused_pattern
(
gpd
.
mutable_pattern
(),
"preln_groupnorm_act_fuse"
);
fused_pattern
(
x
,
y
);
fused_pattern
(
x
,
y
,
with_act
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
...
...
@@ -129,6 +131,9 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
VLOG
(
4
)
<<
"handle preln groupnorm act fuse"
;
Node
*
act
=
nullptr
;
Node
*
act_out
=
nullptr
;
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise
,
elementwise
,
fused_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
elementwise_out
,
elementwise_out
,
fused_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
group_norm
,
group_norm
,
fused_pattern
);
...
...
@@ -136,8 +141,12 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH
(
group_norm_scale
,
group_norm_scale
,
fused_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
group_norm_out
,
group_norm_out
,
fused_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act
,
act
,
fused_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
act_out
,
act_out
,
fused_pattern
);
if
(
with_act
)
{
GET_IR_NODE_FROM_SUBGRAPH
(
tmp_act
,
act
,
fused_pattern
);
GET_IR_NODE_FROM_SUBGRAPH
(
tmp_act_out
,
act_out
,
fused_pattern
);
act
=
tmp_act
;
act_out
=
tmp_act_out
;
}
if
(
!
IsCompat
(
subgraph
,
graph
))
{
LOG
(
WARNING
)
<<
"preln groupnorm act pass in op compat failed."
;
...
...
@@ -150,8 +159,13 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
new_desc
.
SetType
(
"preln_groupnorm_act"
);
new_desc
.
SetInput
(
"X"
,
{
subgraph
.
at
(
x
)
->
Name
()});
new_desc
.
SetInput
(
"Y"
,
{
subgraph
.
at
(
y
)
->
Name
()});
new_desc
.
SetAttr
(
"with_silu"
,
with_act
);
new_desc
.
SetOutput
(
"Out_0"
,
{
elementwise_out
->
Name
()});
new_desc
.
SetOutput
(
"Out_1"
,
{
act_out
->
Name
()});
if
(
with_act
)
{
new_desc
.
SetOutput
(
"Out_1"
,
{
act_out
->
Name
()});
}
else
{
new_desc
.
SetOutput
(
"Out_1"
,
{
group_norm_out
->
Name
()});
}
new_desc
.
RemoveOutput
(
"Y"
);
new_desc
.
Flush
();
...
...
@@ -159,15 +173,21 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
del_node_set
.
insert
(
elementwise
);
del_node_set
.
insert
(
group_norm
);
del_node_set
.
insert
(
group_norm_out
);
del_node_set
.
insert
(
act
);
if
(
with_act
)
{
del_node_set
.
insert
(
act
);
del_node_set
.
insert
(
group_norm_out
);
}
GraphSafeRemoveNodes
(
graph
,
del_node_set
);
IR_NODE_LINK_TO
(
subgraph
.
at
(
x
),
fused_node
);
IR_NODE_LINK_TO
(
subgraph
.
at
(
y
),
fused_node
);
IR_NODE_LINK_TO
(
group_norm_scale
,
fused_node
);
IR_NODE_LINK_TO
(
group_norm_bias
,
fused_node
);
IR_NODE_LINK_TO
(
fused_node
,
act_out
);
if
(
with_act
)
{
IR_NODE_LINK_TO
(
fused_node
,
act_out
);
}
else
{
IR_NODE_LINK_TO
(
fused_node
,
group_norm_out
);
}
IR_NODE_LINK_TO
(
fused_node
,
elementwise_out
);
found_subgraph_count
++
;
};
...
...
@@ -178,7 +198,8 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
// Pass entry point: initializes the pass on `graph`, runs the
// elementwise_add + group_norm fusion for both pattern variants, and reports
// the total number of fused subgraphs through AddStatis().
//
// The stale call to the superseded ApplyGNSiluPattern(graph) helper has been
// dropped: it declared `found_subgraph_count` a second time in the same scope
// (a redefinition error) next to the ApplyAddGNPattern-based replacement.
void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
  FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph);
  // Variant with a trailing silu is matched first, then the variant without.
  // NOTE(review): the order is kept as-is; presumably the longer pattern must
  // run first so the shorter one does not consume its nodes -- confirm.
  int found_subgraph_count = ApplyAddGNPattern(graph, true);
  found_subgraph_count += ApplyAddGNPattern(graph, false);
  AddStatis(found_subgraph_count);
}
...
...
paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h
浏览文件 @
adcb0039
...
...
@@ -25,7 +25,7 @@ namespace ir {
// | | -> preln_gn_act
// other op group_norm | |
// | other op
// silu
// silu
(optional)
// |
class
Graph
;
...
...
@@ -88,7 +88,7 @@ class PrelnGroupNormActFusePass : public FusePassBase {
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
int
Apply
GNSiluPattern
(
ir
::
Graph
*
graph
)
const
;
int
Apply
AddGNPattern
(
ir
::
Graph
*
graph
,
bool
with_act
)
const
;
};
}
// namespace ir
...
...
paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc
浏览文件 @
adcb0039
...
...
@@ -45,6 +45,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
int
groups
=
PADDLE_GET_CONST
(
int
,
op_desc
.
GetAttr
(
"groups"
));
float
epsilon
=
PADDLE_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"epsilon"
));
bool
with_silu
=
PADDLE_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"with_silu"
));
std
::
string
scale_name
=
op_desc
.
Input
(
"Scale"
).
front
();
std
::
string
bias_name
=
op_desc
.
Input
(
"Bias"
).
front
();
...
...
@@ -75,6 +76,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
bias_weights
.
get
().
count
,
epsilon
,
groups
,
with_silu
,
with_fp16
);
nvinfer1
::
ILayer
*
groupnorm_layer
=
engine_
->
AddDynamicPlugin
(
inputs
.
data
(),
2
,
plugin
);
...
...
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
浏览文件 @
adcb0039
...
...
@@ -431,7 +431,7 @@ int PrelnGroupnormActPluginDynamic::enqueue(
if
(
cPerBlock
>
input_desc
[
0
].
dims
.
d
[
1
])
{
cPerBlock
=
8
;
}
params_
.
withSwish
=
true
;
params_
.
withSwish
=
with_silu_
;
params_
.
dst
=
static_cast
<
half
*>
(
outputs
[
1
]);
params_
.
eleOut
=
static_cast
<
half
*>
(
outputs
[
0
]);
params_
.
srcX
=
static_cast
<
half
const
*>
(
inputs
[
0
]);
...
...
paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h
浏览文件 @
adcb0039
...
...
@@ -36,6 +36,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
const
int
bias_num
,
float
eps
,
int
groups
,
bool
with_silu
,
bool
with_fp16
,
std
::
shared_ptr
<
void
>
scale_gpu
=
nullptr
,
std
::
shared_ptr
<
void
>
bias_gpu
=
nullptr
)
...
...
@@ -43,6 +44,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
bias_gpu_
(
bias_gpu
),
groups_
(
groups
),
eps_
(
eps
),
with_silu_
(
with_silu
),
with_fp16_
(
with_fp16
)
{
scale_
.
resize
(
scale_num
);
bias_
.
resize
(
bias_num
);
...
...
@@ -69,6 +71,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
bias_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
eps_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
groups_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
with_silu_
);
DeserializeValue
(
&
serialData
,
&
serialLength
,
&
with_fp16_
);
{
...
...
@@ -97,6 +100,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
bias_
.
size
(),
eps_
,
groups_
,
with_silu_
,
with_fp16_
,
scale_gpu_
,
bias_gpu_
);
...
...
@@ -112,13 +116,14 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
// Returns the number of bytes serialize() writes for this plugin.
//
// Every field written by serialize() must be counted here, in particular
// with_silu_: omitting it makes the reported size smaller than what
// serialize() actually writes.
size_t getSerializationSize() const TRT_NOEXCEPT override {
  return SerializedSize(scale_) + SerializedSize(bias_) +
         SerializedSize(eps_) + SerializedSize(groups_) +
         SerializedSize(with_silu_) + SerializedSize(with_fp16_);
}
// Writes the plugin state into `buffer`.
//
// Field order is significant: scale_, bias_, eps_, groups_, with_silu_,
// with_fp16_. Keep it in sync with the DeserializeValue calls in the
// deserializing constructor and with getSerializationSize().
void serialize(void* buffer) const TRT_NOEXCEPT override {
  // Weights first.
  SerializeValue(&buffer, scale_);
  SerializeValue(&buffer, bias_);
  // Then scalar hyper-parameters.
  SerializeValue(&buffer, eps_);
  SerializeValue(&buffer, groups_);
  // Finally the boolean switches.
  SerializeValue(&buffer, with_silu_);
  SerializeValue(&buffer, with_fp16_);
}
nvinfer1
::
DimsExprs
getOutputDimensions
(
...
...
@@ -171,6 +176,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
GroupNormNHWCParams
params_
;
int
groups_
;
float
eps_
;
bool
with_silu_
;
bool
with_fp16_
;
};
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
浏览文件 @
adcb0039
...
...
@@ -169,5 +169,141 @@ class TestElementGNActPass(PassAutoScanTest):
)
class TestElementGNNoActPass(PassAutoScanTest):
    # Auto-scan test for the elementwise_add + group_norm fusion when no
    # activation follows group_norm:
    #
    #      |          |                          |          |
    #  other_op1  other_op2                  other_op1  other_op2
    #      |          |          fuse             \        /
    #     elementwise_add         ->         preln_groupnorm_act
    #      |          |                          |          |
    #  other_op3  group_norm                 other_op3
    #                 |
    #
    def sample_predictor_configs(self, program_config):
        # TensorRT engine in half precision with dynamic shapes.
        config = self.create_trt_inference_config()
        config.enable_tensorrt_engine(
            max_batch_size=1,
            workspace_size=102400,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Half,
            use_static=False,
            use_calib_mode=False,
        )

        # The three shape dicts are passed positionally in this order.
        # NOTE(review): presumably (min, max, optimal) -- confirm against the
        # set_trt_dynamic_shape_info signature.
        first_shapes = {
            "input_data_x": [1, 160, 1, 1],
            "input_data_y": [1, 160, 1, 1],
        }
        second_shapes = {
            "input_data_x": [4, 1280, 64, 64],
            "input_data_y": [4, 1280, 64, 64],
        }
        third_shapes = {
            "input_data_x": [1, 320, 32, 32],
            "input_data_y": [1, 320, 32, 32],
        }
        config.set_trt_dynamic_shape_info(
            first_shapes, second_shapes, third_shapes
        )

        yield config, ['preln_groupnorm_act'], (3e-3, 1e-3)

    def sample_program_config(self, draw):
        # Draw order is kept fixed: axis, epsilon, batch_size, groups, hw,
        # channel.
        axis = draw(st.sampled_from([0, -1]))
        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
        batch_size = draw(st.integers(min_value=1, max_value=4))
        groups = draw(st.sampled_from([4, 8, 16, 32]))
        hw = draw(st.sampled_from([1, 8, 16, 32]))
        channel = draw(st.sampled_from([320, 1280]))

        def generate_input(attrs, dim_key):
            # Random float32 tensor of shape [batch_size, *attrs[1][dim_key]].
            shape = [attrs[1]["batch_size"], *attrs[1][dim_key]]
            return np.random.random(shape).astype(np.float32)

        def generate_weight(attrs):
            # Random float32 vector with one entry per input channel.
            num_channels = attrs[1]['input_dim_x'][0]
            return np.random.random(num_channels).astype(np.float32)

        # attrs[0]: operator attributes, attrs[1]: tensor shape information.
        op_attrs = {
            'axis': axis,
            'epsilon': epsilon,
            'groups': groups,
        }
        shape_attrs = {
            'batch_size': batch_size,
            'input_dim_x': [channel, hw, hw],
            'input_dim_y': [channel, hw, hw],
        }
        attrs = [op_attrs, shape_attrs]

        elementwise_add_op = OpConfig(
            type="elementwise_add",
            inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
            outputs={"Out": ["ele_out"]},
            attrs={"axis": attrs[0]['axis']},
        )
        group_norm_op = OpConfig(
            type="group_norm",
            inputs={
                "X": ["ele_out"],
                "Bias": ["group_norm_bias"],
                "Scale": ["group_norm_scale"],
            },
            outputs={
                "Y": ["group_norm_output1"],
                "Mean": ["group_norm_output2"],
                "Variance": ["group_norm_output3"],
            },
            attrs={
                "data_layout": "NCHW",
                "groups": attrs[0]["groups"],
                "epsilon": attrs[0]["epsilon"],
            },
        )

        weight_configs = {
            "group_norm_bias": TensorConfig(
                data_gen=partial(generate_weight, attrs)
            ),
            "group_norm_scale": TensorConfig(
                data_gen=partial(generate_weight, attrs)
            ),
        }
        input_configs = {
            "input_data_x": TensorConfig(
                data_gen=partial(generate_input, attrs, "input_dim_x")
            ),
            "input_data_y": TensorConfig(
                data_gen=partial(generate_input, attrs, "input_dim_y")
            ),
        }

        # Both the residual sum and the normalized result are program outputs.
        return ProgramConfig(
            ops=[elementwise_add_op, group_norm_op],
            weights=weight_configs,
            inputs=input_configs,
            outputs=["ele_out", "group_norm_output1"],
        )

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=50,
            passes=["preln_elementwise_groupnorm_act_pass"],
            max_duration=250,
            min_success_num=50,
        )
# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录