未验证 提交 206a33b3 编写于 作者: B baoachun 提交者: GitHub

add conv_gelu_mkldnn_fuse_pass (#38107)

* add conv_gelu_mkldnn_fuse_pass

* add post ops
上级 fff6e77c
...@@ -69,7 +69,15 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -69,7 +69,15 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
desc->SetOutput("Output", desc->SetOutput("Output",
std::vector<std::string>({activation_out->Name()})); std::vector<std::string>({activation_out->Name()}));
if (activation_type() == "gelu" &&
activation->Op()->HasAttr("approximate")) {
bool approximate =
BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"));
std::string type = approximate ? "_tanh" : "_erf";
desc->SetAttr("fuse_activation", "gelu" + type);
} else {
desc->SetAttr("fuse_activation", activation_type()); desc->SetAttr("fuse_activation", activation_type());
}
// MKLDNN ops use alpha and beta as activation parameters but paddle ops are // MKLDNN ops use alpha and beta as activation parameters but paddle ops are
// not generalized // not generalized
...@@ -240,6 +248,19 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() { ...@@ -240,6 +248,19 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
.End(); .End();
} }
// Declares the operator-compatibility constraints for the "gelu" op that this
// pass handles: an input "X" and an output "Out" that must both be tensors,
// plus an "approximate" attribute that must be of type bool.
Conv2DGeluFusePass::Conv2DGeluFusePass() {
  auto& gelu_compat = AddOpCompat(OpCompat("gelu"));

  // Input/output constraints.
  gelu_compat.AddInput("X").IsTensor().End();
  gelu_compat.AddOutput("Out").IsTensor().End();

  // Attribute constraint.
  gelu_compat.AddAttr("approximate").IsType<bool>().End();
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -294,3 +315,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass) ...@@ -294,3 +315,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1) .LE("conv2d", 1)
.EQ("hard_sigmoid", 0)); .EQ("hard_sigmoid", 0));
// Register Conv2DGeluFusePass under the name "conv_gelu_mkldnn_fuse_pass";
// this is the name used when the pass is listed in a pass strategy or
// requested by a test.
REGISTER_PASS(conv_gelu_mkldnn_fuse_pass,
paddle::framework::ir::Conv2DGeluFusePass);
// Declare the operator-version combination associated with this pass:
// LE("conv2d", 1) and EQ("gelu", 0).
// NOTE(review): presumably these mean "conv2d op version <= 1" and
// "gelu op version == 0" — confirm against OpVersionComparatorCombination.
REGISTER_PASS_CAPABILITY(conv_gelu_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("gelu", 0));
...@@ -81,6 +81,15 @@ class Conv2DHardSigmoidFusePass : public ConvActivationFusePass { ...@@ -81,6 +81,15 @@ class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
std::string activation_type() const { return "hard_sigmoid"; } std::string activation_type() const { return "hard_sigmoid"; }
}; };
/*
 * Conv2DGeluFusePass: the ConvActivationFusePass variant whose activation
 * type is "gelu", i.e. the pass used to fuse a conv op with a gelu op.
 */
class Conv2DGeluFusePass : public ConvActivationFusePass {
 public:
  // Defined out of line in the corresponding source file.
  Conv2DGeluFusePass();

  // Name of the activation op type handled by this pass.
  std::string activation_type() const { return "gelu"; }
};
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -250,6 +250,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -250,6 +250,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_swish_mkldnn_fuse_pass", // "conv_swish_mkldnn_fuse_pass", //
"conv_hard_swish_mkldnn_fuse_pass", // "conv_hard_swish_mkldnn_fuse_pass", //
"conv_hard_sigmoid_mkldnn_fuse_pass", // "conv_hard_sigmoid_mkldnn_fuse_pass", //
"conv_gelu_mkldnn_fuse_pass",
"scale_matmul_fuse_pass", // "scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", //
"reshape_transpose_matmul_v2_mkldnn_fuse_pass", // "reshape_transpose_matmul_v2_mkldnn_fuse_pass", //
......
...@@ -510,6 +510,12 @@ class ConvMKLDNNHandlerT ...@@ -510,6 +510,12 @@ class ConvMKLDNNHandlerT
fuse_alpha, fuse_beta); fuse_alpha, fuse_beta);
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_clip, 0.0f, post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_clip, 0.0f,
1.0f); 1.0f);
} else if (fuse_activation == "gelu_tanh") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh,
0.0f, 0.0f);
} else if (fuse_activation == "gelu_erf") {
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf,
0.0f, 0.0f);
} }
conv_attr.set_post_ops(post_operations); conv_attr.set_post_ops(post_operations);
return conv_attr; return conv_attr;
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume
import hypothesis.strategies as st
class TestConvGeluMkldnnFusePass(PassAutoScanTest):
    """Auto-scan test for ``conv_gelu_mkldnn_fuse_pass``.

    Every sampled program is a ``conv2d`` op whose output feeds a ``gelu``
    op. The program is run with an MKLDNN-enabled inference config and the
    fuse pass applied.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return False for programs whose first op uses the NHWC layout."""
        op_attrs = [op.attrs for op in program_config.ops]
        # If the problem has been fixed, the judgment
        # needs to be deleted!!!
        return op_attrs[0]['data_format'] != "NHWC"

    def sample_program_config(self, draw):
        """Draw conv2d/gelu attributes and build the two-op ProgramConfig.

        The order of the ``draw`` calls is kept stable on purpose so that
        Hypothesis generates the same sequence of examples.
        """
        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
        dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
        groups = draw(st.sampled_from([1, 2, 4]))
        paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
        strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
        approximate = draw(st.booleans())
        batch_size = draw(st.integers(min_value=1, max_value=4))

        def generate_input():
            # The position of the size-48 axis depends on data_format.
            if data_format == "NCHW":
                shape = [batch_size, 48, 64, 64]
            else:
                shape = [batch_size, 64, 64, 48]
            return np.random.random(shape).astype(np.float32)

        def generate_weight():
            # Second dimension is 48 divided by the sampled group count.
            weight_shape = [16, 48 // groups, 3, 3]
            return np.random.random(weight_shape).astype(np.float32)

        conv2d_op = {
            "op_type": "conv2d",
            "op_inputs": {
                "Input": ["input_data"],
                "Filter": ["input_weight"]
            },
            "op_outputs": {
                "Output": ["conv_output"]
            },
            "op_attrs": {
                "data_format": data_format,
                "dilations": dilations,
                "padding_algorithm": padding_algorithm,
                "groups": groups,
                "paddings": paddings,
                "strides": strides
            }
        }
        gelu_op = {
            "op_type": "gelu",
            "op_inputs": {
                "X": ["conv_output"]
            },
            "op_outputs": {
                "Out": ["gelu_output"]
            },
            "op_attrs": {
                "approximate": approximate,
            },
        }

        ops = self.generate_op_config([conv2d_op, gelu_op])

        weights = {
            "input_weight": TensorConfig(data_gen=partial(generate_weight))
        }
        inputs = {
            "input_data": TensorConfig(data_gen=partial(generate_input)),
        }
        return ProgramConfig(
            ops=ops, weights=weights, inputs=inputs, outputs=["gelu_output"])

    def sample_predictor_configs(self, program_config):
        """Yield one MKLDNN-enabled inference config for the program.

        The config is yielded together with ``["conv2d"]`` and
        ``(1e-5, 1e-5)``.
        NOTE(review): presumably the op types expected after the pass and
        the comparison tolerances -- confirm against PassAutoScanTest.
        """
        mkldnn_config = self.create_inference_config(use_mkldnn=True)
        yield mkldnn_config, ["conv2d"], (1e-5, 1e-5)

    # If the problem has been fixed, the judgment
    # needs to be deleted!!!
    def add_ignore_pass_case(self):
        """Register an ignore-check case for programs using the NHWC layout."""

        def is_nhwc_program(program_config, predictor_config):
            return program_config.ops[0].attrs['data_format'] == "NHWC"

        self.add_ignore_check_case(
            is_nhwc_program, SkipReasons.PASS_ACCURACY_ERROR,
            "The output format of conv2d is wrong when data_format attribute is NHWC"
        )

    def test(self):
        """Run the auto-scan with only ``conv_gelu_mkldnn_fuse_pass``."""
        self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"])
# Run this module's unittest test cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册