From 3333a43986d72c39e6685abef28b2be43fc0fc4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C5=82awomir=20Siwek?=
Date: Tue, 12 Jul 2022 10:05:02 +0200
Subject: [PATCH] matmul+activation fuse pass (#43519)

* add method for post ops
* format code
* gpd
* format style
* add matmul+act test
* implement matmul+activation
* whitespaces
* code style
* python code format
* Increase UT timeout
* code format
* update style
* generalize activation fuse passes
* change order
* Unify activation GPD
* Revert changes with op_act
* remove softmax mkldnn attrs
* set common name for act attributes
* whitespace
* append postops by helper function
* ut style
* revert changes related to quantization
* Reduce redundancy
* reduce number of parameters
* trigger CI
* validate attribute
* trim unit test
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 .../conv_activation_mkldnn_fuse_pass.cc       |   1 -
 .../matmul_activation_mkldnn_fuse_pass.cc     | 281 ++++++++++++++++++
 .../matmul_activation_mkldnn_fuse_pass.h      |  41 +++
 ...plus_activation_mkldnn_fuse_pass_tester.cc |   4 +-
 .../inference/api/paddle_pass_builder.cc      |   1 +
 .../operators/mkldnn/matmul_mkldnn_op.cc      |   3 +
 ...test_mkldnn_matmul_activation_fuse_pass.py | 127 ++++++++
 8 files changed, 456 insertions(+), 3 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 2e4b73c6ac..d31555bf72 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -212,6 +212,7 @@ if(WITH_MKLDNN)
   pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
   pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
+  pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 8c140e8132..bd07967757 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -26,7 +26,6 @@ using string::PrettyLogDetail;
 
 void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
   auto act_types = paddle::platform::GetSupportedActivations();
-
   std::vector<std::string> conv_types = {"conv2d"};
 
   for (const auto& conv_type : conv_types)
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
new file mode 100644
index 0000000000..80f49c97e8
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
@@ -0,0 +1,281 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
+
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
+  auto act_types = paddle::platform::GetSupportedActivations();
+  std::vector<std::string> matmul_types = {"matmul"};
+
+  for (const auto& matmul_type : matmul_types)
+    for (auto& act_type : act_types) {
+      FuseMatmulAct(graph, matmul_type, act_type);
+    }
+}
+
+void MatmulActivationMkldnnFusePass::FuseMatmulAct(
+    Graph* graph, const std::string& matmul_type, std::string& act_type) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::OperatorActivation matmul_act_pattern(
+      gpd.mutable_pattern(), "matmul_activation_mkldnn_fuse");
+  matmul_act_pattern(matmul_type, act_type);
+
+  int found_matmul_activation_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "handle " + matmul_type + "+" + act_type + " fuse";
+
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "matmul_activation_mkldnn_fuse_pass op compat failed.";
+      return;
+    }
+
+    GET_IR_NODE_FROM_SUBGRAPH(matmul, preceding_op, matmul_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, preceding_op_out, matmul_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(activation, activation, matmul_act_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        activation_out, activation_out, matmul_act_pattern);
+
+    OpDesc* matmul_op = matmul->Op();
+    OpDesc* act_op = activation->Op();
+
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto& attrs : attr_map) {
+      if (act_op->HasAttr(attrs.first)) {
+        matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
+      }
+    }
+
+    if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
+      act_type =
+          BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
"gelu_tanh" + : "gelu_erf"; + } + matmul_op->SetAttr("fuse_activation", act_type); + matmul_op->SetOutput("Out", {activation_out->Name()}); + + IR_NODE_LINK_TO(matmul, activation_out); + GraphSafeRemoveNodes(graph, {activation, matmul_out}); + found_matmul_activation_count++; + }; + + gpd(graph, handler); + AddStatis(found_matmul_activation_count); + if (!Has("disable_logs") || !Get("disable_logs")) { + PrettyLogDetail("--- fused %d matmul with %s activation", + found_matmul_activation_count, + act_type); + } +} + +MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() + .End() + .AddAttr("transpose_X") + .IsType() + .End() + .AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("abs")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("clip")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("min") + .End() + .AddAttr("max") + .End(); + + AddOpCompat(OpCompat("gelu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("approximate") + .IsType() + .End(); + + AddOpCompat(OpCompat("hard_sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("slope") + .IsOptional() + .IsType() + .End() + .AddAttr("offset") + .IsOptional() + .IsType() + .End(); + + AddOpCompat(OpCompat("hard_swish")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("threshold") + .IsOptional() + .IsType() + .End() + .AddAttr("scale") + .IsOptional() + .IsType() + .End() + .AddAttr("offset") + .IsOptional() + .IsType() + .End(); + + AddOpCompat(OpCompat("leaky_relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsType() + .End(); + + AddOpCompat(OpCompat("mish")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("relu6")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("threshold") + .IsType() + .End(); + + AddOpCompat(OpCompat("sigmoid")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("sqrt")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + AddOpCompat(OpCompat("swish")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("beta") + .IsType() + .End(); + + AddOpCompat(OpCompat("tanh")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(matmul_activation_mkldnn_fuse_pass, + paddle::framework::ir::MatmulActivationMkldnnFusePass); + +REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("abs", 0) + .LE("clip", 1) + .EQ("gelu", 0) + .EQ("hard_sigmoid", 0) + .LE("hard_swish", 0) + .LE("leaky_relu", 1) + .LE("mish", 1) + .EQ("relu", 0) + .EQ("relu6", 0) + .EQ("sigmoid", 0) + .EQ("sqrt", 0) + .EQ("swish", 0) + .EQ("tanh", 
+            .EQ("tanh", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h
new file mode 100644
index 0000000000..ebef63e292
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class MatmulActivationMkldnnFusePass : public FusePassBase {
+ public:
+  MatmulActivationMkldnnFusePass();
+  virtual ~MatmulActivationMkldnnFusePass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+
+  void FuseMatmulAct(Graph *graph,
+                     const std::string &matmul_type,
+                     std::string &act_type) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc
index 2fbb46e32d..afe3d75fd2 100644
--- a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc
@@ -50,9 +50,9 @@ void MainTest(const std::string& activation_type) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("fuse_activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto activation_type =
-          BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation_type"));
+          BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(activation_type.compare(activation_type), 0);
     }
   }
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 0d918446ea..3642a28790 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -302,6 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "softplus_activation_mkldnn_fuse_pass",  //
              "shuffle_channel_mkldnn_detect_pass",    //
              "elt_act_mkldnn_fuse_pass",              //
+             "matmul_activation_mkldnn_fuse_pass",    //
              // TODO(intel): Please fix the bug on windows.
              // https://github.com/PaddlePaddle/Paddle/issues/29710
              // "mkldnn_inplace_pass",  // This pass should be activated after
diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
index 9ab09c3f3c..912b1be813 100644
--- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include <tuple>
 
 #include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 
 using dnnl::memory;
 using dnnl::primitive;
@@ -453,6 +454,8 @@ class MatMulMKLDNNHandler
       matmul_attrs.set_output_scales(0, {scale_out});
     }
 
+    paddle::platform::AppendActivation(ctx, post_operations);
+
     matmul_attrs.set_post_ops(post_operations);
     return matmul_attrs;
   }
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
new file mode 100644
index 0000000000..20028fb335
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+import hypothesis.strategies as st
+
+
+class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
+
+    def sample_program_config(self, draw):
+        transpose_X = draw(st.booleans())
+        transpose_Y = draw(st.booleans())
+        alpha = draw(st.sampled_from([1, 2]))
+        batch_size = draw(st.sampled_from([4]))
+        channel = draw(st.sampled_from([8]))
+        input_dim = draw(st.sampled_from([32]))
+        activation_type = draw(
+            st.sampled_from([
+                'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
+                'hard_swish', 'sigmoid', 'abs', 'relu6', 'clip', 'tanh',
+                'hard_sigmoid', 'leaky_relu'
+            ]))
+
+        def generate_input(type):
+            if transpose_X and transpose_Y:
+                shape_x = [batch_size, channel, input_dim, 32]
+                shape_y = [batch_size, channel, 64, input_dim]
+            elif transpose_X:
+                shape_x = [batch_size, channel, input_dim, 32]
+                shape_y = [batch_size, channel, input_dim, 64]
+            elif transpose_Y:
+                shape_x = [batch_size, channel, 32, input_dim]
+                shape_y = [batch_size, channel, 8, input_dim]
+            else:
+                shape_x = [batch_size, channel, 32, input_dim]
+                shape_y = [batch_size, channel, input_dim, 16]
+
+            if type == 'x':
+                return np.random.random(shape_x).astype(np.float32)
+            else:
+                return np.random.random(shape_y).astype(np.float32)
+
+        matmul_op = OpConfig(type='matmul',
+                             inputs={
+                                 'X': ['matmul_X'],
+                                 'Y': ['matmul_Y']
+                             },
+                             outputs={'Out': ['matmul_output']},
+                             attrs={
+                                 'transpose_X': transpose_X,
+                                 'transpose_Y': transpose_Y,
+                                 'alpha': alpha
+                             })
+
+        if activation_type == "relu6":
+            activation_op = OpConfig(activation_type,
+                                     inputs={"X": ["matmul_output"]},
+                                     outputs={"Out": ["activation_output"]},
+                                     threshold=draw(
+                                         st.floats(min_value=1.0,
+                                                   max_value=10.0)))
+        elif activation_type == "leaky_relu":
+            activation_op = OpConfig(activation_type,
+                                     inputs={"X": ["matmul_output"]},
+                                     outputs={"Out": ["activation_output"]},
+                                     alpha=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == "swish":
+            activation_op = OpConfig(activation_type,
+                                     inputs={"X": ["matmul_output"]},
["matmul_output"]}, + outputs={"Out": ["activation_output"]}, + beta=draw( + st.floats(min_value=0.1, + max_value=1.0))) + elif activation_type == "clip": + activation_op = OpConfig( + activation_type, + inputs={"X": ["matmul_output"]}, + outputs={"Out": ["activation_output"]}, + min=draw(st.floats(min_value=0.1, max_value=0.49)), + max=draw(st.floats(min_value=0.5, max_value=1.0))) + else: + activation_op = OpConfig(activation_type, + inputs={"X": ["matmul_output"]}, + outputs={"Out": ["activation_output"]}) + + model_net = [matmul_op, activation_op] + + program_config = ProgramConfig( + ops=model_net, + weights={}, + inputs={ + 'matmul_X': TensorConfig(data_gen=partial(generate_input, 'x')), + 'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y')) + }, + outputs=['activation_output']) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, ['matmul'], (1e-5, 1e-5) + + def test(self): + self.run_and_statis(quant=False, + max_examples=30, + passes=['matmul_activation_mkldnn_fuse_pass']) + + +if __name__ == '__main__': + unittest.main() -- GitLab