From 206a33b3ae0caeb6d676d21334d084601457a2b7 Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Tue, 14 Dec 2021 18:35:07 +0800 Subject: [PATCH] add conv_gelu_mkldnn_fuse_pass (#38107) * add conv_gelu_mkldnn_fuse_pass * add post ops --- .../conv_activation_mkldnn_fuse_pass.cc | 31 ++++- .../mkldnn/conv_activation_mkldnn_fuse_pass.h | 9 ++ .../inference/api/paddle_pass_builder.cc | 13 +- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 6 + .../test_mkldnn_conv_gelu_fuse_pass.py | 129 ++++++++++++++++++ 5 files changed, 181 insertions(+), 7 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index c817400056c..cfd40435387 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -69,7 +69,15 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { desc->SetOutput("Output", std::vector({activation_out->Name()})); - desc->SetAttr("fuse_activation", activation_type()); + if (activation_type() == "gelu" && + activation->Op()->HasAttr("approximate")) { + bool approximate = + BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate")); + std::string type = approximate ? "_tanh" : "_erf"; + desc->SetAttr("fuse_activation", "gelu" + type); + } else { + desc->SetAttr("fuse_activation", activation_type()); + } // MKLDNN ops use alpha and beta as activation parameters but paddle ops are // not generalized @@ -240,6 +248,19 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() { .End(); } +Conv2DGeluFusePass::Conv2DGeluFusePass() { + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); +} + } // namespace ir } // namespace framework } // namespace paddle @@ -294,3 +315,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass) paddle::framework::compatible::OpVersionComparatorCombination() .LE("conv2d", 1) .EQ("hard_sigmoid", 0)); + +REGISTER_PASS(conv_gelu_mkldnn_fuse_pass, + paddle::framework::ir::Conv2DGeluFusePass); +REGISTER_PASS_CAPABILITY(conv_gelu_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("conv2d", 1) + .EQ("gelu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h index eacde101d5a..b8279e48386 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h @@ -81,6 +81,15 @@ class Conv2DHardSigmoidFusePass : public ConvActivationFusePass { std::string activation_type() const { return "hard_sigmoid"; } }; +/* + * Fuse Conv and Gelu class + */ +class Conv2DGeluFusePass : public ConvActivationFusePass { + public: + Conv2DGeluFusePass(); + std::string activation_type() const { return "gelu"; } +}; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index d571973a83f..de2de112344 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -244,12 +244,13 @@ void CpuPassStrategy::EnableMKLDNN() { "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass", // - "conv_leaky_relu_mkldnn_fuse_pass", // - "conv_relu6_mkldnn_fuse_pass", // - "conv_swish_mkldnn_fuse_pass", // - "conv_hard_swish_mkldnn_fuse_pass", // - "conv_hard_sigmoid_mkldnn_fuse_pass", // + "conv_relu_mkldnn_fuse_pass", // + "conv_leaky_relu_mkldnn_fuse_pass", // + "conv_relu6_mkldnn_fuse_pass", // + "conv_swish_mkldnn_fuse_pass", // + "conv_hard_swish_mkldnn_fuse_pass", // + "conv_hard_sigmoid_mkldnn_fuse_pass", // + "conv_gelu_mkldnn_fuse_pass", "scale_matmul_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", // "reshape_transpose_matmul_v2_mkldnn_fuse_pass", // diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index d499b273885..d584da72393 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -510,6 +510,12 @@ class ConvMKLDNNHandlerT fuse_alpha, fuse_beta); post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_clip, 0.0f, 1.0f); + } else if (fuse_activation == "gelu_tanh") { + post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh, + 0.0f, 0.0f); + } else if (fuse_activation == "gelu_erf") { + post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf, + 0.0f, 0.0f); } conv_attr.set_post_ops(post_operations); return conv_attr; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py new file mode 100644 index 00000000000..aa779f6ecbc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_gelu_fuse_pass.py @@ -0,0 +1,129 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestConvGeluMkldnnFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + # If the problem has been fixed, the judgment + # needs to be deleted!!! + if attrs[0]['data_format'] == "NHWC": + return False + + return True + + def sample_program_config(self, draw): + data_format = draw(st.sampled_from(["NCHW", "NHWC"])) + dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) + padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) + groups = draw(st.sampled_from([1, 2, 4])) + paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]])) + strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) + approximate = draw(st.booleans()) + batch_size = draw(st.integers(min_value=1, max_value=4)) + + def generate_input(): + if data_format == "NCHW": + return np.random.random( + [batch_size, 48, 64, 64]).astype(np.float32) + else: + return np.random.random( + [batch_size, 64, 64, 48]).astype(np.float32) + + def generate_weight(): + return np.random.random( + [16, int(48 / groups), 3, 3]).astype(np.float32) + + ops_config = [{ + "op_type": "conv2d", + "op_inputs": { + "Input": ["input_data"], + "Filter": ["input_weight"] + }, + "op_outputs": { + "Output": ["conv_output"] + }, + "op_attrs": { + "data_format": data_format, + "dilations": dilations, + "padding_algorithm": padding_algorithm, + "groups": groups, + "paddings": paddings, + "strides": strides + } + }, { + "op_type": "gelu", + "op_inputs": { + "X": ["conv_output"] + }, + "op_outputs": { + "Out": ["gelu_output"] + }, + "op_attrs": { + "approximate": approximate, + }, + }] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "input_weight": TensorConfig(data_gen=partial(generate_weight)) + }, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input)), + }, + outputs=["gelu_output"]) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, ["conv2d"], (1e-5, 1e-5) + + # If the problem has been fixed, the judgment + # needs to be deleted!!! + def add_ignore_pass_case(self): + def teller1(program_config, predictor_config): + if program_config.ops[0].attrs['data_format'] == "NHWC": + return True + return False + + self.add_ignore_check_case( + teller1, SkipReasons.PASS_ACCURACY_ERROR, + "The output format of conv2d is wrong when data_format attribute is NHWC" + ) + + def test(self): + self.run_and_statis(quant=False, passes=["conv_gelu_mkldnn_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main() -- GitLab