diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 680ae9a681a0b8474591dbf5b2f9d9e484d2fcca..0b5af21ca5c467935331e596bec41339972129d8 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -212,6 +212,7 @@ if(WITH_MKLDNN)
   pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
   pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
+  pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index cce1ec89a2e82e3e4a11a7ebc85823f3663c972a..85b3bdb874d4f5ebf15d10d9998e18ed90f6945d 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2018,6 +2018,33 @@ PDNode *patterns::ElementwiseOp::operator()(
   return out_var;
 }
 
+PDNode *patterns::MatmulElementwiseAdd::operator()(
+    const std::string &matmul_type, bool as_x) {
+  auto matmul_op =
+      pattern->NewNode(matmul_op_repr())->assert_is_op(matmul_type);
+  auto matmul_out =
+      pattern->NewNode(matmul_out_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output(matmul_type, "Out")
+          ->assert_is_only_output_of_op(matmul_type)
+          ->assert_is_op_input("elementwise_add", as_x ? "X" : "Y");
+  auto elementwise_addend =
+      pattern->NewNode(elementwise_addend_repr())
+          ->AsInput()
+          ->assert_is_op_input("elementwise_add", as_x ? "Y" : "X");
+  auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
+                                ->assert_is_op("elementwise_add");
+  auto elementwise_add_out =
+      pattern->NewNode(elementwise_add_out_repr())
+          ->AsOutput()
+          ->assert_is_op_output("elementwise_add", "Out");
+
+  matmul_op->LinksTo({matmul_out});
+  elementwise_add_op->LinksFrom({matmul_out, elementwise_addend})
+      .LinksTo({elementwise_add_out});
+  return elementwise_add_out;
+}
+
 PDNode *patterns::ResidualElementwise::operator()(
     PDNode *op_var,
     PDNode *residual_var,
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 794c25e85a555fd93f750222afeba8d3896e289a..f0f7282683b710519256b0d10c94628dcf6d676d 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1038,6 +1038,21 @@ struct ElementwiseOp : public PatternBase {
   PATTERN_DECL_NODE(elementwise_out);
 };
 
+struct MatmulElementwiseAdd : public PatternBase {
+  MatmulElementwiseAdd(PDPattern* pattern,
+                       const std::string& name_scope,
+                       const std::string& matmul_type,
+                       bool as_x)
+      : PatternBase(pattern, name_scope, "matmul_elementwise_add") {}
+
+  PDNode* operator()(const std::string& matmul_type, bool as_x);
+  PATTERN_DECL_NODE(matmul_op);
+  PATTERN_DECL_NODE(matmul_out);
+  PATTERN_DECL_NODE(elementwise_addend);
+  PATTERN_DECL_NODE(elementwise_add_op);
+  PATTERN_DECL_NODE(elementwise_add_out);
+};
+
 // Residual Elementwise ops
 // This pattern allows operator output to be X or Y
 // and residual data Y or X, based on as_x flag
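For orientation: the pattern above matches the chain matmul(_v2) -> elementwise_add, with the matmul output feeding either the X or the Y input of the add, and exposes the matched nodes under the PATTERN_DECL_NODE names. A minimal sketch of how a pass drives such a pattern follows (a hypothetical helper for illustration only; the real consumer is the new pass later in this diff):

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical walk-through: find every matmul -> elementwise_add chain
// (matmul output used as X) and visit the matched nodes.
void VisitMatmulElementwiseAdd(Graph* graph) {
  GraphPatternDetector gpd;
  patterns::MatmulElementwiseAdd pattern(
      gpd.mutable_pattern(), "example_scope", "matmul", /*as_x=*/true);
  pattern("matmul", /*as_x=*/true);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // Binds local Node* variables to the nodes matched under the
    // PATTERN_DECL_NODE names declared above.
    GET_IR_NODE_FROM_SUBGRAPH(matmul, matmul_op, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(addend, elementwise_addend, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(add_out, elementwise_add_out, pattern);
    VLOG(4) << "matched " << matmul->Name() << " feeding " << add_out->Name()
            << " with addend " << addend->Name();
  };
  gpd(graph, handler);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle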
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
index f1c1b57f3f662a61c3824a80f5556070ef866c7f..d64fbe16a3eb472593e11e12136dfec808a2c20f 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.cc
@@ -199,8 +199,11 @@ class DeQuantizer final : public Quanter {
   bool IsNotPermittedName(const std::string& output_name) const override {
     std::unordered_map<std::string, std::vector<std::string>> block_list{
         {"layer_norm",
-         {"Mean", "Variance"}},  // not used in inference in MKLDNN
-        {"fc", {"ResidualData"}}};  // artifical output, already dequantized
+         {"Mean", "Variance"}},        // not used in inference in MKLDNN
+        {"fc", {"ResidualData"}},      // artificial output, already dequantized
+        {"matmul", {"ResidualData"}},  // artificial output, already dequantized
+        {"matmul_v2",
+         {"ResidualData"}}};  // artificial output, already dequantized
 
     std::vector<std::string> blocked_outputs{"XShape"};  // blocklist for any op
     auto op_name = op->Name();
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
index d3f71e498bfe846153d722c4e5d17c6c8d9d8115..9ba89106c3471e60f60899f7d8c2e2fdaa4228a8 100644
--- a/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
@@ -26,7 +26,7 @@ using string::PrettyLogDetail;
 
 void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
   auto act_types = paddle::platform::GetSupportedActivations();
-  std::vector<std::string> matmul_types = {"matmul"};
+  auto matmul_types = {"matmul", "matmul_v2"};
 
   for (const auto& matmul_type : matmul_types)
     for (auto& act_type : act_types) {
@@ -88,8 +88,9 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
   gpd(graph, handler);
   AddStatis(found_matmul_activation_count);
   if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
-    PrettyLogDetail("--- fused %d matmul with %s activation",
+    PrettyLogDetail("--- fused %d %s with %s activation",
                     found_matmul_activation_count,
+                    matmul_type,
                     act_type);
   }
 }
@@ -102,6 +103,11 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .AddInput("Y")
       .IsTensor()
       .End()
+      .AddInput(
+          "ResidualData")  // Extra tensor used in matmul+elementwise_add fuse
+      .IsTensor()
+      .IsOptional()
+      .End()
       .AddOutput("Out")
       .IsTensor()
       .End()
@@ -115,6 +121,28 @@ MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
       .IsType<bool>()
      .End();
 
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddInput(
+          "ResidualData")  // Extra tensor used in matmul+elementwise_add fuse
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsType<bool>()
+      .End()
+      .AddAttr("trans_y")
+      .IsType<bool>()
+      .End();
+
   AddOpCompat(OpCompat("abs"))
       .AddInput("X")
       .IsTensor()
@@ -267,6 +295,7 @@ REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("matmul", 1)
+            .EQ("matmul_v2", 0)
             .EQ("abs", 0)
             .LE("clip", 1)
             .EQ("gelu", 0)
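With the activation pass now tolerating a fused ResidualData input, running matmul_elementwise_add_mkldnn_fuse_pass first and matmul_activation_mkldnn_fuse_pass second collapses matmul -> elementwise_add -> activation into a single primitive. In oneDNN terms the accumulated attribute ends up roughly like this (a sketch against the oneDNN 2.x API; residual_md stands for the addend's memory descriptor, relu is just an example activation):

dnnl::post_ops ops;
// Residual add first, activation second -- mirroring the original graph order.
ops.append_binary(dnnl::algorithm::binary_add, residual_md);           // elementwise_add
ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);   // activation
dnnl::primitive_attr attrs;
attrs.set_post_ops(ops);

The ordering matters: the addend is applied to the raw matmul result before the activation, which is why the elementwise_add pass must run before the activation pass in the pass list below.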
diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2e6e450cd4c72324e036c78586b10f55f5dfc83c
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.cc
@@ -0,0 +1,157 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h"
+
+#include "paddle/fluid/framework/ir/graph_traits.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void MatmulElementwiseAddMKLDNNFusePass::ApplyImpl(Graph* graph) const {
+  auto matmul_types = {"matmul", "matmul_v2"};
+  auto matmul_as_x = {true, false};
+
+  for (const auto& matmul_type : matmul_types)
+    for (const auto& as_x : matmul_as_x) {
+      FuseMatmulElementwiseAdd(graph, matmul_type, as_x);
+    }
+}
+
+void MatmulElementwiseAddMKLDNNFusePass::FuseMatmulElementwiseAdd(
+    Graph* graph, const std::string& matmul_type, bool matmul_as_x) const {
+  const std::string fusion_mode = matmul_as_x ? "x" : "y";
+  const auto name_scope = matmul_type + "_elementwise_add_as_" + fusion_mode;
+  FusePassBase::Init(name_scope, graph);
+  GraphPatternDetector gpd;
+  auto pattern = gpd.mutable_pattern();
+  patterns::MatmulElementwiseAdd matmul_pattern(
+      pattern, name_scope, matmul_type, matmul_as_x);
+  matmul_pattern(matmul_type, matmul_as_x);
+
+  int found_matmul_elementwise_add_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(matmul, matmul_op, matmul_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_add, elementwise_add_op, matmul_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_addend, elementwise_addend, matmul_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_add_out, elementwise_add_out, matmul_pattern);
+
+    if (FindFuseOption(*matmul, *elementwise_add) != FUSE_MKLDNN) return;
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING)
+          << "op compat for matmul_elementwise_add_mkldnn_fuse_pass failed.";
+      return;
+    }
+    // ResidualData is an input, not an attribute, so check the input map to
+    // make sure this matmul has not already absorbed an elementwise_add.
+    if (matmul->Op()->Inputs().count("ResidualData") &&
+        !matmul->Op()->Inputs().at("ResidualData").empty()) {
+      LOG(WARNING) << "matmul_elementwise_add can be fused only once";
+      return;
+    }
+
+    matmul->Op()->SetInput("ResidualData", {elementwise_addend->Name()});
+    matmul->Op()->SetOutput("Out", {elementwise_add_out->Name()});
+
+    GraphSafeRemoveNodes(g, {matmul_out, elementwise_add});
+
+    IR_NODE_LINK_TO(elementwise_addend, matmul);
+    IR_NODE_LINK_TO(matmul, elementwise_add_out);
+
+    found_matmul_elementwise_add_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_matmul_elementwise_add_count);
+  if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
+    PrettyLogDetail("--- fused %d %s (as %s) with elementwise_add",
+                    found_matmul_elementwise_add_count,
+                    matmul_type,
+                    fusion_mode);
+  }
+}
+
+MatmulElementwiseAddMKLDNNFusePass::MatmulElementwiseAddMKLDNNFusePass() {
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End()
+      .AddAttr("transpose_X")
+      .IsType<bool>()
+      .End()
.AddAttr("transpose_Y") + .IsType() + .End(); + + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType() + .End() + .AddAttr("trans_y") + .IsType() + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 0, 1}) + .End(); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(matmul_elementwise_add_mkldnn_fuse_pass, + paddle::framework::ir::MatmulElementwiseAddMKLDNNFusePass); +REGISTER_PASS_CAPABILITY(matmul_elementwise_add_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("matmul_v2", 0) + .LE("elementwise_add", 1)); diff --git a/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..c630fd0b8741e37ae2df4671965598979113c82d --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/matmul_elementwise_add_mkldnn_fuse_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class MatmulElementwiseAddMKLDNNFusePass : public FusePassBase { + public: + MatmulElementwiseAddMKLDNNFusePass(); + virtual ~MatmulElementwiseAddMKLDNNFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + void FuseMatmulElementwiseAdd(Graph* graph, + const std::string& matmul_type, + bool matmul_as_x) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 52b5d524495812bdd07aa577789c9227beabf7e4..235fd99535fa0d69dedd4f0da33c727c90825e73 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -305,6 +305,8 @@ void CpuPassStrategy::EnableMKLDNN() { "reshape_transpose_matmul_v2_mkldnn_fuse_pass", // "matmul_transpose_reshape_fuse_pass", // "matmul_v2_transpose_reshape_fuse_pass", // + "matmul_elementwise_add_mkldnn_fuse_pass", // + "matmul_activation_mkldnn_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", // "fc_act_mkldnn_fuse_pass", @@ -313,7 +315,6 @@ void CpuPassStrategy::EnableMKLDNN() { "softplus_activation_mkldnn_fuse_pass", // "shuffle_channel_mkldnn_detect_pass", // "elt_act_mkldnn_fuse_pass", // - "matmul_activation_mkldnn_fuse_pass", // // TODO(intel): Please fix the bug on windows. 
diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
index 5c3dd0cb1234aff810cbb480c9ee40db37eb6363..02632673b958437c8f5115decce106567187bc2d 100644
--- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc
@@ -446,7 +446,6 @@ class MatMulMKLDNNHandler
     if (scale_out != 1.0f) {
       matmul_attrs.set_output_scales(0, {scale_out});
     }
-    paddle::platform::AppendActivation(ctx, post_operations);
 
     matmul_attrs.set_post_ops(post_operations);
 
@@ -698,6 +697,13 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
       {DNNL_ARG_WEIGHTS, *weights_memory_p},
       {DNNL_ARG_DST, *dst_memory_p}};
 
+  if (ctx.HasInput("ResidualData")) {
+    auto *residual_data = ctx.Input<Tensor>("ResidualData");
+    const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
+    matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
+                        *residual_data_memory_p});
+  }
+
   auto &astream = MKLDNNDeviceContext::tls().get_stream();
   matmul_p->execute(astream, matmul_args);
   astream.wait();
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index 85b0775c751dc0257bfc753f572b668da5add609..6c802b682ec5f24cbd697ccef4edefd495db9b90 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -958,6 +958,16 @@ class MatMulV2MKLDNNHandler
       matmul_attrs.set_output_scales(0, {alpha});
     }
 
+    if (ctx.HasInput("ResidualData")) {
+      auto* residual_data = ctx.Input<Tensor>("ResidualData");
+      auto residual_data_tz = phi::vectorize(residual_data->dims());
+      auto residual_data_md = memory::desc(residual_data_tz,
+                                           dnnl::memory::data_type::f32,
+                                           dnnl::memory::format_tag::abcd);
+      post_operations.append_binary(dnnl::algorithm::binary_add,
+                                    residual_data_md);
+    }
+
     AppendActivation(ctx, post_operations);
 
     matmul_attrs.set_post_ops(post_operations);
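The ResidualData plumbing above is oneDNN's binary post-op mechanism: the addend's descriptor is baked into the primitive attribute at creation time, and the tensor itself is supplied at execution as an extra runtime argument keyed by DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1. A self-contained sketch of that mechanism against the oneDNN 2.x API (the version Paddle built against here; shapes are illustrative):

#include <unordered_map>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // (2x3x4x5) x (2x3x5x4) matmul with a (2x3x4x4) addend fused as binary_add.
  memory::desc src_md({2, 3, 4, 5}, memory::data_type::f32,
                      memory::format_tag::abcd);
  memory::desc wei_md({2, 3, 5, 4}, memory::data_type::f32,
                      memory::format_tag::abcd);
  memory::desc dst_md({2, 3, 4, 4}, memory::data_type::f32,
                      memory::format_tag::abcd);
  memory::desc res_md({2, 3, 4, 4}, memory::data_type::f32,
                      memory::format_tag::abcd);

  post_ops ops;
  ops.append_binary(algorithm::binary_add, res_md);  // fused addend
  primitive_attr attrs;
  attrs.set_post_ops(ops);

  auto mm_pd = matmul::primitive_desc(
      matmul::desc(src_md, wei_md, dst_md), attrs, eng);
  matmul mm(mm_pd);

  memory src(src_md, eng), wei(wei_md, eng), dst(dst_md, eng), res(res_md, eng);
  std::unordered_map<int, memory> args{
      {DNNL_ARG_SRC, src},
      {DNNL_ARG_WEIGHTS, wei},
      {DNNL_ARG_DST, dst},
      // The addend travels as the runtime SRC_1 of post-op #0:
      {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, res}};
  mm.execute(strm, args);
  strm.wait();
  return 0;
}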
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27ed9dd9c99a2d7000e2d634165c97904a5e81a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_elementwise_add_activation_fuse_pass.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+
+import hypothesis.strategies as st
+
+
+class TestMatmulElementwiseAddActivationMkldnnFusePass(PassAutoScanTest):
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([-1, 0, 1]))
+        matmul_as_x = draw(st.booleans())
+        batch_size = draw(st.integers(min_value=2, max_value=4))
+        channel = draw(st.sampled_from([16, 32, 64]))
+        input_dim = draw(st.sampled_from([16, 32, 64]))
+        activation_type = draw(
+            st.sampled_from([
+                'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
+                'hard_swish', 'abs', 'relu6', 'clip', 'hard_sigmoid',
+                'leaky_relu'
+            ]))
+
+        def generate_input():
+            return np.random.random(
+                [batch_size, channel, input_dim, input_dim]).astype(np.float32)
+
+        matmul_op = OpConfig(type='matmul',
+                             inputs={
+                                 'X': ['matmul_x'],
+                                 'Y': ['matmul_y']
+                             },
+                             outputs={'Out': ['matmul_output']},
+                             attrs={
+                                 'use_mkldnn': True,
+                             })
+
+        if matmul_as_x:
+            inputs = {'X': ['matmul_output'], 'Y': ['elementwise_addend']}
+        else:
+            inputs = {'X': ['elementwise_addend'], 'Y': ['matmul_output']}
+
+        elt_add_op = OpConfig(type='elementwise_add',
+                              inputs=inputs,
+                              outputs={'Out': ['elementwise_add_output']},
+                              attrs={
+                                  'axis': axis,
+                                  'use_mkldnn': True
+                              })
+
+        if activation_type == 'relu6':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['elementwise_add_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     threshold=draw(
+                                         st.floats(min_value=1.0,
+                                                   max_value=10.0)))
+        elif activation_type == 'leaky_relu':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['elementwise_add_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     alpha=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == 'swish':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['elementwise_add_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     beta=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == 'clip':
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['elementwise_add_output']},
+                outputs={'Out': ['activation_output']},
+                min=draw(st.floats(min_value=0.1, max_value=0.49)),
+                max=draw(st.floats(min_value=0.5, max_value=1.0)))
+        else:
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['elementwise_add_output']},
+                                     outputs={'Out': ['activation_output']})
+
+        model_net = [matmul_op, elt_add_op, activation_op]
+
+        program_config = ProgramConfig(
+            ops=model_net,
+            weights={},
+            inputs={
+                'matmul_x': TensorConfig(data_gen=partial(generate_input)),
+                'matmul_y': TensorConfig(data_gen=partial(generate_input)),
+                'elementwise_addend':
+                TensorConfig(data_gen=partial(generate_input))
+            },
+            outputs=['activation_output'])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(
+            use_mkldnn=True,
+            passes=[
+                'matmul_elementwise_add_mkldnn_fuse_pass',
+                'matmul_activation_mkldnn_fuse_pass'
+            ])
+        yield config, ['matmul'], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(quant=False,
+                            passes=[
+                                'matmul_elementwise_add_mkldnn_fuse_pass',
+                                'matmul_activation_mkldnn_fuse_pass'
+                            ])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_elementwise_add_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_elementwise_add_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c8985dbad1ff2eb43bec9e8c755906d6d0cad1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_elementwise_add_fuse_pass.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+
+import hypothesis.strategies as st
+
+
+class TestMatmulElementwiseAddMkldnnFusePass(PassAutoScanTest):
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([-1, 0, 1]))
+        matmul_as_x = draw(st.booleans())
+        batch_size = draw(st.integers(min_value=2, max_value=4))
+        channel = draw(st.sampled_from([16, 32, 64]))
+        input_dim = draw(st.sampled_from([16, 32, 64]))
+
+        def generate_input():
+            return np.random.random(
+                [batch_size, channel, input_dim, input_dim]).astype(np.float32)
+
+        matmul_op = OpConfig(type='matmul',
+                             inputs={
+                                 'X': ['matmul_x'],
+                                 'Y': ['matmul_y']
+                             },
+                             outputs={'Out': ['matmul_output']},
+                             attrs={
+                                 'use_mkldnn': True,
+                             })
+
+        if matmul_as_x:
+            inputs = {'X': ['matmul_output'], 'Y': ['elementwise_addend']}
+        else:
+            inputs = {'X': ['elementwise_addend'], 'Y': ['matmul_output']}
+
+        elt_add_op = OpConfig(type='elementwise_add',
+                              inputs=inputs,
+                              outputs={'Out': ['elementwise_add_output']},
+                              attrs={
+                                  'axis': axis,
+                                  'use_mkldnn': True
+                              })
+
+        model_net = [matmul_op, elt_add_op]
+
+        program_config = ProgramConfig(
+            ops=model_net,
+            weights={},
+            inputs={
+                'matmul_x': TensorConfig(data_gen=partial(generate_input)),
+                'matmul_y': TensorConfig(data_gen=partial(generate_input)),
+                'elementwise_addend':
+                TensorConfig(data_gen=partial(generate_input))
+            },
+            outputs=['elementwise_add_output'])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(
+            use_mkldnn=True,
+            passes=['matmul_elementwise_add_mkldnn_fuse_pass'])
+        yield config, ['matmul'], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(
+            quant=False, passes=['matmul_elementwise_add_mkldnn_fuse_pass'])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..2858d7f2d4e33058e08f2edea96153f7d94214c0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_activation_fuse_pass.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+import hypothesis.strategies as st
+
+
+class TestMatmulv2ActivationMkldnnFusePass(PassAutoScanTest):
+
+    def sample_program_config(self, draw):
+        transpose_X = draw(st.booleans())
+        transpose_Y = draw(st.booleans())
+        batch_size = draw(st.integers(min_value=2, max_value=4))
+        channel = draw(st.sampled_from([16, 32, 64]))
+        input_dim = draw(st.sampled_from([16, 32, 64]))
+        # drawn once so both generated tensors agree on which side broadcasts
+        broadcast_X = draw(st.booleans())
+        activation_type = draw(
+            st.sampled_from([
+                'relu', 'gelu', 'tanh', 'sigmoid', 'swish', 'mish', 'sqrt',
+                'hard_swish', 'abs', 'relu6', 'clip', 'hard_sigmoid',
+                'leaky_relu'
+            ]))
+
+        def generate_input(type):
+            channel_X = 1 if broadcast_X else channel
+            channel_Y = channel if broadcast_X else 1
+            batch_size_X = 1 if broadcast_X else batch_size
+            batch_size_Y = batch_size if broadcast_X else 1
+
+            if transpose_X and transpose_Y:
+                shape_x = [batch_size_X, channel_X, input_dim, 32]
+                shape_y = [batch_size_Y, channel_Y, 64, input_dim]
+            elif transpose_X:
+                shape_x = [batch_size_X, channel_X, input_dim, 32]
+                shape_y = [batch_size_Y, channel_Y, input_dim, 64]
+            elif transpose_Y:
+                shape_x = [batch_size_X, channel_X, 32, input_dim]
+                shape_y = [batch_size_Y, channel_Y, 8, input_dim]
+            else:
+                shape_x = [batch_size_X, channel_X, 32, input_dim]
+                shape_y = [batch_size_Y, channel_Y, input_dim, 16]
+
+            if type == 'X':
+                return np.random.random(shape_x).astype(np.float32)
+            else:
+                return np.random.random(shape_y).astype(np.float32)
+
+        matmul_op = OpConfig(type='matmul_v2',
+                             inputs={
+                                 'X': ['matmul_X'],
+                                 'Y': ['matmul_Y']
+                             },
+                             outputs={'Out': ['matmul_output']},
+                             attrs={
+                                 'trans_x': transpose_X,
+                                 'trans_y': transpose_Y
+                             })
+
+        if activation_type == 'relu6':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['matmul_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     threshold=draw(
+                                         st.floats(min_value=1.0,
+                                                   max_value=10.0)))
+        elif activation_type == 'leaky_relu':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['matmul_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     alpha=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == 'swish':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['matmul_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     beta=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == 'clip':
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['matmul_output']},
+                outputs={'Out': ['activation_output']},
+                min=draw(st.floats(min_value=0.1, max_value=0.49)),
+                max=draw(st.floats(min_value=0.5, max_value=1.0)))
+        else:
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['matmul_output']},
+                                     outputs={'Out': ['activation_output']})
+
+        model_net = [matmul_op, activation_op]
+
+        program_config = ProgramConfig(
+            ops=model_net,
+            weights={},
+            inputs={
+                'matmul_X':
+                TensorConfig(data_gen=partial(generate_input, 'X')),
+                'matmul_Y':
+                TensorConfig(data_gen=partial(generate_input, 'Y'))
+            },
+            outputs=['activation_output'])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ['matmul_v2'], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(quant=False,
+                            max_examples=30,
+                            passes=['matmul_activation_mkldnn_fuse_pass'])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_elementwise_add_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_elementwise_add_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f2867948e916f0aa32d4b3bfee267bfa2d7711
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_elementwise_add_fuse_pass.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+import hypothesis.strategies as st
+
+
+class TestMatmulV2ElementwiseAddMkldnnFusePass(PassAutoScanTest):
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([-1, 0, 1]))
+        matmul_as_x = draw(st.booleans())
+        batch_size = draw(st.integers(min_value=2, max_value=4))
+        channel = draw(st.sampled_from([16, 32, 64]))
+        input_dim_shared = draw(st.sampled_from([16, 32, 64]))
+        input_dim_X = draw(st.sampled_from([16, 32, 64]))
+        input_dim_Y = draw(st.sampled_from([16, 32, 64]))
+        # drawn once so all generated tensors agree on which side broadcasts
+        broadcast_X = draw(st.booleans())
+
+        def generate_input(type):
+            channel_X = 1 if broadcast_X else channel
+            channel_Y = channel if broadcast_X else 1
+            batch_size_X = 1 if broadcast_X else batch_size
+            batch_size_Y = batch_size if broadcast_X else 1
+
+            shape_x = [batch_size_X, channel_X, input_dim_X, input_dim_shared]
+            shape_y = [batch_size_Y, channel_Y, input_dim_shared, input_dim_Y]
+
+            if type == 'X':
+                return np.random.random(shape_x).astype(np.float32)
+            elif type == 'Y':
+                return np.random.random(shape_y).astype(np.float32)
+            else:
+                shape_out = [batch_size, channel, input_dim_X, input_dim_Y]
+                return np.random.random(shape_out).astype(np.float32)
+
+        matmul_op = OpConfig(type='matmul_v2',
+                             inputs={
+                                 'X': ['matmul_X'],
+                                 'Y': ['matmul_Y']
+                             },
+                             outputs={'Out': ['matmul_output']},
+                             attrs={'use_mkldnn': True})
+
+        if matmul_as_x:
+            inputs = {'X': ['matmul_output'], 'Y': ['elementwise_addend']}
+        else:
+            inputs = {'X': ['elementwise_addend'], 'Y': ['matmul_output']}
+
+        elt_add_op = OpConfig(type='elementwise_add',
+                              inputs=inputs,
+                              outputs={'Out': ['elementwise_add_output']},
+                              attrs={
+                                  'axis': axis,
+                                  'use_mkldnn': True
+                              })
+
+        model_net = [matmul_op, elt_add_op]
+
+        program_config = ProgramConfig(
+            ops=model_net,
+            weights={},
+            inputs={
+                'matmul_X':
+                TensorConfig(data_gen=partial(generate_input, 'X')),
+                'matmul_Y':
+                TensorConfig(data_gen=partial(generate_input, 'Y')),
+                'elementwise_addend':
+                TensorConfig(data_gen=partial(generate_input, 'ElAdd'))
+            },
+            outputs=['elementwise_add_output'])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ['matmul_v2'], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(quant=False,
+                            max_examples=30,
+                            passes=['matmul_elementwise_add_mkldnn_fuse_pass'])
+
+
+if __name__ == '__main__':
+    unittest.main()