Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
206a33b3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
206a33b3
编写于
12月 14, 2021
作者:
B
baoachun
提交者:
GitHub
12月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add conv_gelu_mkldnn_fuse_pass (#38107)
* add conv_gelu_mkldnn_fuse_pass * add post ops
上级
fff6e77c
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
181 addition
and
7 deletion
+181
-7
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
...d/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+30
-1
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
...id/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+9
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+7
-6
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+6
-0
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
...unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
+129
-0
未找到文件。
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
浏览文件 @
206a33b3
...
...
@@ -69,7 +69,15 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
desc
->
SetOutput
(
"Output"
,
std
::
vector
<
std
::
string
>
({
activation_out
->
Name
()}));
if
(
activation_type
()
==
"gelu"
&&
activation
->
Op
()
->
HasAttr
(
"approximate"
))
{
bool
approximate
=
BOOST_GET_CONST
(
bool
,
activation
->
Op
()
->
GetAttr
(
"approximate"
));
std
::
string
type
=
approximate
?
"_tanh"
:
"_erf"
;
desc
->
SetAttr
(
"fuse_activation"
,
"gelu"
+
type
);
}
else
{
desc
->
SetAttr
(
"fuse_activation"
,
activation_type
());
}
// MKLDNN ops use alpha and beta as activation parameters but paddle ops are
// not generalized
...
...
@@ -240,6 +248,19 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
.
End
();
}
Conv2DGeluFusePass
::
Conv2DGeluFusePass
()
{
AddOpCompat
(
OpCompat
(
"gelu"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"approximate"
)
.
IsType
<
bool
>
()
.
End
();
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
...
...
@@ -294,3 +315,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"conv2d"
,
1
)
.
EQ
(
"hard_sigmoid"
,
0
));
REGISTER_PASS
(
conv_gelu_mkldnn_fuse_pass
,
paddle
::
framework
::
ir
::
Conv2DGeluFusePass
);
REGISTER_PASS_CAPABILITY
(
conv_gelu_mkldnn_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"conv2d"
,
1
)
.
EQ
(
"gelu"
,
0
));
paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
浏览文件 @
206a33b3
...
...
@@ -81,6 +81,15 @@ class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
std
::
string
activation_type
()
const
{
return
"hard_sigmoid"
;
}
};
/*
* Fuse Conv and Gelu class
*/
class
Conv2DGeluFusePass
:
public
ConvActivationFusePass
{
public:
Conv2DGeluFusePass
();
std
::
string
activation_type
()
const
{
return
"gelu"
;
}
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
206a33b3
...
...
@@ -250,6 +250,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_swish_mkldnn_fuse_pass"
,
//
"conv_hard_swish_mkldnn_fuse_pass"
,
//
"conv_hard_sigmoid_mkldnn_fuse_pass"
,
//
"conv_gelu_mkldnn_fuse_pass"
,
"scale_matmul_fuse_pass"
,
//
"reshape_transpose_matmul_mkldnn_fuse_pass"
,
//
"reshape_transpose_matmul_v2_mkldnn_fuse_pass"
,
//
...
...
paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
浏览文件 @
206a33b3
...
...
@@ -510,6 +510,12 @@ class ConvMKLDNNHandlerT
fuse_alpha
,
fuse_beta
);
post_operations
.
append_eltwise
(
scale
,
dnnl
::
algorithm
::
eltwise_clip
,
0.0
f
,
1.0
f
);
}
else
if
(
fuse_activation
==
"gelu_tanh"
)
{
post_operations
.
append_eltwise
(
scale
,
dnnl
::
algorithm
::
eltwise_gelu_tanh
,
0.0
f
,
0.0
f
);
}
else
if
(
fuse_activation
==
"gelu_erf"
)
{
post_operations
.
append_eltwise
(
scale
,
dnnl
::
algorithm
::
eltwise_gelu_erf
,
0.0
f
,
0.0
f
);
}
conv_attr
.
set_post_ops
(
post_operations
);
return
conv_attr
;
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py
0 → 100644
浏览文件 @
206a33b3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
auto_scan_test
import
PassAutoScanTest
,
SkipReasons
from
program_config
import
TensorConfig
,
ProgramConfig
import
numpy
as
np
import
paddle.inference
as
paddle_infer
from
functools
import
partial
from
typing
import
Optional
,
List
,
Callable
,
Dict
,
Any
,
Set
import
unittest
import
hypothesis
from
hypothesis
import
given
,
settings
,
seed
,
example
,
assume
import
hypothesis.strategies
as
st
class
TestConvGeluMkldnnFusePass
(
PassAutoScanTest
):
def
is_program_valid
(
self
,
program_config
:
ProgramConfig
)
->
bool
:
attrs
=
[
program_config
.
ops
[
i
].
attrs
for
i
in
range
(
len
(
program_config
.
ops
))
]
# If the problem has been fixed, the judgment
# needs to be deleted!!!
if
attrs
[
0
][
'data_format'
]
==
"NHWC"
:
return
False
return
True
def
sample_program_config
(
self
,
draw
):
data_format
=
draw
(
st
.
sampled_from
([
"NCHW"
,
"NHWC"
]))
dilations
=
draw
(
st
.
sampled_from
([[
1
,
1
],
[
2
,
2
],
[
1
,
2
]]))
padding_algorithm
=
draw
(
st
.
sampled_from
([
"EXPLICIT"
,
"SAME"
,
"VALID"
]))
groups
=
draw
(
st
.
sampled_from
([
1
,
2
,
4
]))
paddings
=
draw
(
st
.
sampled_from
([[
0
,
3
],
[
1
,
2
,
3
,
4
]]))
strides
=
draw
(
st
.
sampled_from
([[
1
,
1
],
[
2
,
2
],
[
1
,
2
]]))
approximate
=
draw
(
st
.
booleans
())
batch_size
=
draw
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
))
def
generate_input
():
if
data_format
==
"NCHW"
:
return
np
.
random
.
random
(
[
batch_size
,
48
,
64
,
64
]).
astype
(
np
.
float32
)
else
:
return
np
.
random
.
random
(
[
batch_size
,
64
,
64
,
48
]).
astype
(
np
.
float32
)
def
generate_weight
():
return
np
.
random
.
random
(
[
16
,
int
(
48
/
groups
),
3
,
3
]).
astype
(
np
.
float32
)
ops_config
=
[{
"op_type"
:
"conv2d"
,
"op_inputs"
:
{
"Input"
:
[
"input_data"
],
"Filter"
:
[
"input_weight"
]
},
"op_outputs"
:
{
"Output"
:
[
"conv_output"
]
},
"op_attrs"
:
{
"data_format"
:
data_format
,
"dilations"
:
dilations
,
"padding_algorithm"
:
padding_algorithm
,
"groups"
:
groups
,
"paddings"
:
paddings
,
"strides"
:
strides
}
},
{
"op_type"
:
"gelu"
,
"op_inputs"
:
{
"X"
:
[
"conv_output"
]
},
"op_outputs"
:
{
"Out"
:
[
"gelu_output"
]
},
"op_attrs"
:
{
"approximate"
:
approximate
,
},
}]
ops
=
self
.
generate_op_config
(
ops_config
)
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{
"input_weight"
:
TensorConfig
(
data_gen
=
partial
(
generate_weight
))
},
inputs
=
{
"input_data"
:
TensorConfig
(
data_gen
=
partial
(
generate_input
)),
},
outputs
=
[
"gelu_output"
])
return
program_config
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_mkldnn
=
True
)
yield
config
,
[
"conv2d"
],
(
1e-5
,
1e-5
)
# If the problem has been fixed, the judgment
# needs to be deleted!!!
def
add_ignore_pass_case
(
self
):
def
teller1
(
program_config
,
predictor_config
):
if
program_config
.
ops
[
0
].
attrs
[
'data_format'
]
==
"NHWC"
:
return
True
return
False
self
.
add_ignore_check_case
(
teller1
,
SkipReasons
.
PASS_ACCURACY_ERROR
,
"The output format of conv2d is wrong when data_format attribute is NHWC"
)
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
passes
=
[
"conv_gelu_mkldnn_fuse_pass"
])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录