From 3f219160bee15a3afa7107439197361f8266dc57 Mon Sep 17 00:00:00 2001 From: Tomasz Socha Date: Mon, 14 Mar 2022 12:53:02 +0100 Subject: [PATCH] Add an elementwise + activation fusion pass. (#36541) * Add elementwise add and activation fuse pass * Fix copy ellision * More flexible pattern detector * More flexible fusion pass * Update lists for pass * Add support for Pow operator * Add support for more activation types * Style * Rename fusion pass * First version of tests * Dirty version of pass * Polished version * Update pbtxt * Style * Update names * Style * Use PADDLE_ENFORCE_EQ * Save error message to variable * WO for error checks * CR * Static style check * Add missing 'activation_scale' attribute * Add relu6 and sigmoid activations * Style * Fix fuse list formating * Sync filenames for fuse pass files * Fix cmake after move * Fix registration * Fix pass name in tests * Add missing activations to checker * WIPS * Working mul op * Working sub * Working Add * Remove pten includes * Remove some forward declarations * Remove Includes * Fixes * Remove default kernels * Add check if post_ops attributes are avaliable * Style * Code adjustment * Register default kernels * We have year 2022 not 2021... Co-authored-by: jakpiase Co-authored-by: Sylwester Fraczek * Fast review fixes Co-authored-by: jakpiase Co-authored-by: Sylwester Fraczek * Review Fix * Rename one_dnn -> onednn * Style after review * Fast and dirty fix for quantization * Update tests * Style * Fix mkldnn_quantizer config * Add Joanna's suggestion. * Check if operator is explicitly disables on OneDNN * Try to use unregistered attributes * Style * Test new framework * FXI * FXII * Update test * Style Co-authored-by: jakpiase Co-authored-by: Sylwester Fraczek --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 30 ++ .../framework/ir/graph_pattern_detector.h | 22 ++ .../ir/mkldnn/elt_act_mkldnn_fuse_pass.cc | 145 ++++++++ .../ir/mkldnn/elt_act_mkldnn_fuse_pass.h | 44 +++ .../inference/api/paddle_pass_builder.cc | 1 + .../mkldnn/elementwise_mkldnn_op.h | 45 ++- paddle/fluid/platform/mkldnn_reuse.h | 14 +- .../test_mkldnn_elt_act_fuse_pass.py | 328 ++++++++++++++++++ .../test_mkldnn_elt_act_fuse_pass_new.py | 82 +++++ 10 files changed, 702 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass_new.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a1f2d6edca6..623c8a048c2 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -126,6 +126,7 @@ if(WITH_MKLDNN) pass_library(interpolate_mkldnn_pass inference DIR mkldnn) pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn) + pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) pass_library(cpu_quantize_squash_pass inference DIR mkldnn) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index d7d866fa98b..18068e22b7f 100644 --- 
a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -918,6 +918,36 @@ PDNode *patterns::ConvActivation::operator()( return activation_out_var; } +PDNode *patterns::ElementwiseActivation::operator()( + paddle::framework::ir::PDNode *elementwise_a, + const std::string &elementwise_type, const std::string &activation_type) { + // Create Operators + elementwise_a->assert_is_op_input(elementwise_type, "X"); + auto *elementwise_op = + pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type); + auto *activation_op = + pattern->NewNode(activation_repr())->assert_is_op(activation_type); + // Create variables + auto *elementwise_b = pattern->NewNode(elementwise_b_repr()) + ->AsInput() + ->assert_is_op_input(elementwise_type, "Y"); + // intermediate variable, will be removed in the IR after fuse. + auto *elementwise_out_var = + pattern->NewNode(elementwise_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op(elementwise_type) + ->assert_is_op_input(activation_type); + // output + auto *activation_out_var = pattern->NewNode(activation_out_repr()) + ->AsOutput() + ->assert_is_op_output(activation_type); + + elementwise_op->LinksFrom({elementwise_a, elementwise_b}) + .LinksTo({elementwise_out_var}); + activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var}); + return activation_out_var; +} + PDNode *patterns::SeqConvEltAddRelu::operator()( paddle::framework::ir::PDNode *seqconv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 0f21906d08d..062d2f9dedc 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -487,6 +487,28 @@ struct ConvActivation : public PatternBase { PATTERN_DECL_NODE(activation_out); }; +// Elementwise with Activation +// op: elementwise + activation +// named nodes: +// elementwise_a, elementwise_b, +// elementwise_out, elementwise, +// activation_out, activation +struct ElementwiseActivation : public PatternBase { + ElementwiseActivation(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elementwise_add_activation") {} + + PDNode* operator()(PDNode* elementwise_a, const std::string& elementwise_type, + const std::string& activation_type); + + // declare operator node's name + PATTERN_DECL_NODE(elementwise); + PATTERN_DECL_NODE(activation); + // declare variable node's name + PATTERN_DECL_NODE(elementwise_b); + PATTERN_DECL_NODE(elementwise_out); + PATTERN_DECL_NODE(activation_out); +}; + // SEQCONV with Elementwise_Add ReLU // op: seqconv + elementwise_add + relu // named nodes: diff --git a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc new file mode 100644 index 00000000000..b7f7a8071d2 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const { + std::vector act_types = { + "relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt", + "abs", "clip", "gelu", "relu6", "sigmoid"}; + std::vector elt_types = {"elementwise_add", "elementwise_sub", + "elementwise_mul"}; + + for (const auto &elt_type : elt_types) + for (const auto &act_type : act_types) { + std::unordered_map attr_map; + + if (act_type == "swish") + attr_map.emplace("beta", "activation_alpha"); + else if (act_type == "relu6") + attr_map.emplace("threshold", "activation_alpha"); + else if (act_type == "clip") { + attr_map.emplace("min", "activation_alpha"); + attr_map.emplace("max", "activation_beta"); + } else { + attr_map.emplace("alpha", "activation_alpha"); + attr_map.emplace("beta", "activation_beta"); + } + FuseElementwiseAct(graph, elt_type, act_type, attr_map); + } +} + +void ElementwiseActivationOneDNNPass::FuseElementwiseAct( + Graph *graph, const std::string &elt_type, const std::string &act_type, + const std::unordered_map &attr_map) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init("elementwise_act", graph); + + GraphPatternDetector gpd; + auto *elementwise_input = gpd.mutable_pattern() + ->NewNode(elt_type + "_act/elementwise_input") + ->AsInput() + ->assert_is_op_input(elt_type, "X"); + patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(), + elt_type + "_act"); + elementwise_act_pattern(elementwise_input, elt_type, act_type); + + int found_elementwise_activation_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Fuse " << elt_type << " with activation op."; + // Elementwise output + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, + elementwise_act_pattern); + // ACT output + GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, + elementwise_act_pattern); + // ops + GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, + elementwise_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern); + + auto *elementwise_op = elementwise->Op(); + + if (elementwise_op->HasAttr("use_mkldnn")) { + const std::string wo_elt_type = + "The " + elt_type; // Workaround for PP error message checking. 
+ PADDLE_ENFORCE_EQ( + BOOST_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")), true, + platform::errors::PreconditionNotMet( + wo_elt_type + "+Act fusion may happen only when oneDNN library " + "is used.")); + } + + auto *activation_op = activation->Op(); + for (const auto &attr : attr_map) { + if (activation_op->HasAttr(attr.first)) { + elementwise_op->SetAttr(attr.second, + activation_op->GetAttr(attr.first)); + } + } + + if (act_type == "gelu" && activation_op->HasAttr("approximate") && + BOOST_GET_CONST(bool, activation_op->GetAttr("approximate"))) + elementwise_op->SetAttr("activation_type", std::string("gelu_tanh")); + else + elementwise_op->SetAttr("activation_type", act_type); + + elementwise_op->SetOutput("Out", {activation_out->Name()}); + + IR_OP_VAR_LINK(elementwise, activation_out); + GraphSafeRemoveNodes(g, {activation, elementwise_out}); + found_elementwise_activation_count++; + }; + + gpd(graph, handler); + AddStatis(found_elementwise_activation_count); + PrettyLogDetail("--- fused %d %s with %s activation", + found_elementwise_activation_count, elt_type, act_type); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(elt_act_mkldnn_fuse_pass, + paddle::framework::ir::ElementwiseActivationOneDNNPass); +REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .LE("elementwise_sub", 1) + .LE("elementwise_mul", 1) + .LE("relu", 0) + .LE("tanh", 0) + .LE("leaky_relu", 1) + .LE("swish", 0) + .LE("hard_swish", 0) + .LE("sqrt", 0) + .LE("abs", 0) + .LE("clip", 1) + .LE("gelu", 0) + .LE("relu6", 0) + .LE("sigmoid", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h new file mode 100644 index 00000000000..b8b7d06a828 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * \brief Fuse the Elementwise and activation operators into single + * OneDNN's Elementwise with post-op. 
+ */ +class ElementwiseActivationOneDNNPass : public FusePassBase { + public: + virtual ~ElementwiseActivationOneDNNPass() {} + + protected: + void ApplyImpl(Graph *graph) const override; + + void FuseElementwiseAct( + Graph *graph, const std::string &elt_types, const std::string &act_types, + const std::unordered_map &attr_map) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index f5f36d805b4..22d9dedb32e 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -262,6 +262,7 @@ void CpuPassStrategy::EnableMKLDNN() { // "fc_act_mkldnn_fuse_pass", "batch_norm_act_fuse_pass", // "softplus_activation_mkldnn_fuse_pass", // + "elt_act_mkldnn_fuse_pass", // // TODO(intel): Please fix the bug on windows. // https://github.com/PaddlePaddle/Paddle/issues/29710 // "mkldnn_inplace_pass", // This pass should be activated after diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index 763fc5f2674..ad8fd317013 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -32,6 +32,45 @@ using dnnl::stream; template class EltwiseMKLDNNKernel : public framework::OpKernel { + private: + dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const { + dnnl::post_ops post_operations; + if (ctx.HasAttr("activation_type")) { + const float scale = ctx.HasAttr("activation_scale") + ? ctx.Attr("activation_scale") + : 1.0f; + const float alpha = ctx.HasAttr("activation_alpha") + ? ctx.Attr("activation_alpha") + : 0.0f; + const float beta = ctx.HasAttr("activation_beta") + ? 
ctx.Attr("activation_beta") + : 0.0f; + + static std::unordered_map algo_map = { + {"relu", dnnl::algorithm::eltwise_relu}, + {"tanh", dnnl::algorithm::eltwise_tanh}, + {"leaky_relu", dnnl::algorithm::eltwise_relu}, + {"swish", dnnl::algorithm::eltwise_swish}, + {"hardswish", dnnl::algorithm::eltwise_hardswish}, + {"sqrt", dnnl::algorithm::eltwise_sqrt}, + {"abs", dnnl::algorithm::eltwise_abs}, + {"clip", dnnl::algorithm::eltwise_clip}, + {"gelu", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, + {"relu6", dnnl::algorithm::eltwise_bounded_relu}, + {"sigmoid", dnnl::algorithm::eltwise_logistic}}; + + const auto& activation_type = + algo_map.find(ctx.Attr("activation_type")); + + if (activation_type != algo_map.end()) { + post_operations.append_eltwise(scale, activation_type->second, alpha, + beta); + } + } + return post_operations; + } + public: void Compute(const framework::ExecutionContext& ctx) const override { const auto& dev_ctx = @@ -47,9 +86,9 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { float scale_o = ctx.Attr("Scale_out"); int axis = ctx.Attr("axis"); - platform::BinaryMKLDNNHandler handler(BINARY_OP, axis, mkldnn_engine, - ctx.GetPlace(), x, y, z, scale_x, - scale_y, scale_o); + platform::BinaryMKLDNNHandler handler( + BINARY_OP, axis, mkldnn_engine, ctx.GetPlace(), x, y, z, scale_x, + scale_y, scale_o, get_post_ops(ctx)); const auto src_x_memory = handler.AcquireSrcMemory(x); const auto src_y_memory = handler.AcquireSecondSrcMemory(y); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 01de7349f48..1254331835b 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -618,7 +618,7 @@ class BinaryMKLDNNHandler const dnnl::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, Tensor* z, float scale_x, float scale_y, float scale_z, - const dnnl::post_ops& post_ops = dnnl::post_ops()) + const dnnl::post_ops& post_ops = dnnl::post_ops{}) : platform::MKLDNNHandlerNoCachingT(engine, cpu_place) { PADDLE_ENFORCE_EQ( x->layout(), DataLayout::kMKLDNN, @@ -676,8 +676,8 @@ class BinaryMKLDNNHandler const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType(), MKLDNNMemoryFormat::any); - auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z); - attributes.set_post_ops(post_ops); + auto attributes = + CreateAttributes(algo, scale_x, scale_y, scale_z, post_ops); this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md, dst_md); @@ -690,10 +690,9 @@ class BinaryMKLDNNHandler } private: - static inline dnnl::primitive_attr CreateAttributes(dnnl::algorithm op, - float scale_x, - float scale_y, - float scale_z) { + static inline dnnl::primitive_attr CreateAttributes( + dnnl::algorithm op, float scale_x, float scale_y, float scale_z, + dnnl::post_ops post_ops = dnnl::post_ops{}) { // Scales set in attributes for inputs contibute to the output equation // in the following way (assuming no broadcasting takes place): // output_i = scale_0 * x_i <+ or *> scale_1 * y_i; @@ -718,6 +717,7 @@ class BinaryMKLDNNHandler {scale_0}); attributes.set_scales(/* input_y_id = */ DNNL_ARG_SRC_1, /* mask = */ 0, {scale_1}); + if (post_ops.len() > 0) attributes.set_post_ops(post_ops); return attributes; } }; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py new file mode 100644 index 
00000000000..893bd383343 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass.py @@ -0,0 +1,328 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle +import paddle.fluid as fluid +from paddle.fluid.core import PassVersionChecker + + +class ElementwiseActivationMkldnnFusePassTest(InferencePassTest): + act_alpha = None + act_beta = None + pass_name = 'elt_act_mkldnn_fuse_pass' + + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data_A = fluid.data( + name="data_A", shape=[-1, 3, 100, 100], dtype="float32") + data_B = fluid.data( + name="data_B", shape=[-1, 3, 100, 100], dtype="float32") + elt_out = self.operand(data_A, data_B) + if self.act is not None: + if self.act_beta is not None: + elt_out = self.act(elt_out, self.act_alpha, self.act_beta) + elif self.act_alpha is not None: + elt_out = self.act(elt_out, self.act_alpha) + else: + elt_out = self.act(elt_out) + + self.feeds = { + "data_A": np.random.random((1, 3, 100, 100)).astype("float32"), + "data_B": np.random.random((1, 3, 100, 100)).astype("float32") + } + self.fetch_list = [elt_out] + self.enable_mkldnn = True + + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = None + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class ElementwiseActivationMkldnnFusePassTest_Add_Relu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.relu + + +class ElementwiseActivationMkldnnFusePassTest_Add_Tanh( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.tanh + + +class ElementwiseActivationMkldnnFusePassTest_Add_LeakyRelu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act_alpha = 0.2 + self.act = fluid.layers.leaky_relu + + +class ElementwiseActivationMkldnnFusePassTest_Add_Swish( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act_alpha = 4 + self.act = fluid.layers.swish + + +class ElementwiseActivationMkldnnFusePassTest_Add_HardSwish( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.hard_swish + + +class ElementwiseActivationMkldnnFusePassTest_Add_SQRT( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.sqrt + + +class ElementwiseActivationMkldnnFusePassTest_Add_ABS( 
+ ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.abs + + +class ElementwiseActivationMkldnnFusePassTest_Add_Clip( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.clip + self.act_alpha = 0.0 + self.act_beta = 10.0 + + +class ElementwiseActivationMkldnnFusePassTest_Add_Gelu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.gelu + + +class ElementwiseActivationMkldnnFusePassTest_Add_Gelu_Tanh( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.gelu + self.act_alpha = True + + +class ElementwiseActivationMkldnnFusePassTest_Add_Relu6( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.relu6 + self.act_alpha = 5.0 + + +class ElementwiseActivationMkldnnFusePassTest_Add_Sigmoid( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_add + self.act = fluid.layers.sigmoid + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Relu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.relu + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Tanh( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.tanh + + +class ElementwiseActivationMkldnnFusePassTest_Sub_LeakyRelu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act_alpha = 0.2 + self.act = fluid.layers.leaky_relu + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Swish( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.swish + + +class ElementwiseActivationMkldnnFusePassTest_Sub_HardSwish( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.hard_swish + + +class ElementwiseActivationMkldnnFusePassTest_Sub_ABS( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.abs + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Clip( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.clip + self.act_alpha = 0.0 + self.act_beta = 10.0 + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Gelu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.gelu + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Gelu_Tanh( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.gelu + self.act_alpha = True + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Relu6( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.relu6 + self.act_alpha = 5.0 + + +class ElementwiseActivationMkldnnFusePassTest_Sub_Sigmoid( + 
ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_sub + self.act = fluid.layers.sigmoid + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Relu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.relu + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Tanh( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.tanh + + +class ElementwiseActivationMkldnnFusePassTest_Mul_LeakyRelu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act_alpha = 0.2 + self.act = fluid.layers.leaky_relu + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Swish( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.swish + + +class ElementwiseActivationMkldnnFusePassTest_Mul_HardSwish( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.hard_swish + + +class ElementwiseActivationMkldnnFusePassTest_Mul_SQRT( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.sqrt + + +class ElementwiseActivationMkldnnFusePassTest_Mul_ABS( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.abs + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Clip( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.clip + self.act_alpha = 0.0 + self.act_beta = 10.0 + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Gelu( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.gelu + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Gelu_Tanh( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.gelu + self.act_alpha = True + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Relu6( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.relu6 + self.act_alpha = 5.0 + + +class ElementwiseActivationMkldnnFusePassTest_Mul_Sigmoid( + ElementwiseActivationMkldnnFusePassTest): + def set_params(self): + self.operand = fluid.layers.elementwise_mul + self.act = fluid.layers.sigmoid + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass_new.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass_new.py new file mode 100644 index 00000000000..0f5279b0eda --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_elt_act_fuse_pass_new.py @@ -0,0 +1,82 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume +import hypothesis.strategies as st + + +class TestElementWiseAddReluFusePass(PassAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_config(self, draw): + batch_size = draw(st.integers(min_value=1, max_value=4)) + + def generate_input(): + return np.random.random( + [batch_size, 3, 100, 100]).astype(np.float32) + + ops_config = [{ + "op_type": "elementwise_add", + "op_inputs": { + "X": ["A"], + "Y": ["B"] + }, + "op_outputs": { + "Out": ["add_output"] + }, + "op_attrs": {} + }, { + "op_type": "relu", + "op_inputs": { + "X": ["add_output"] + }, + "op_outputs": { + "Out": ["relu_output"] + }, + "op_attrs": {} + }] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "A": TensorConfig(data_gen=partial(generate_input)), + "B": TensorConfig(data_gen=partial(generate_input)) + }, + outputs=["relu_output"]) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, ["elementwise_add"], (1e-5, 1e-5) + + def test(self): + self.run_and_statis( + quant=False, passes=["elt_act_mkldnn_fuse_pass"], min_success_num=4) + + +if __name__ == "__main__": + unittest.main() -- GitLab
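
Not part of the patch: a minimal standalone sketch of the mechanism the fused kernel relies on, i.e. an oneDNN binary primitive (elementwise_add) with an eltwise post-op attached through primitive attributes. This mirrors in spirit what get_post_ops() in elementwise_mkldnn_op.h and BinaryMKLDNNHandler in mkldnn_reuse.h set up from the "activation_type"/"activation_alpha"/"activation_beta"/"activation_scale" attributes written by elt_act_mkldnn_fuse_pass. It assumes the oneDNN 2.x C++ API used by Paddle at the time (append_eltwise still takes a scale argument there; oneDNN 3.x changed this signature), and the tensor shape simply matches the unit tests above.

// fused_add_relu_sketch.cc — hedged illustration only, not Paddle code.
#include <vector>
#include "oneapi/dnnl/dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // Same shape as the Python tests: [1, 3, 100, 100], f32, NCHW.
  memory::dims dims = {1, 3, 100, 100};
  auto md = memory::desc(dims, memory::data_type::f32,
                         memory::format_tag::nchw);

  // Equivalent of activation_type="relu" with default
  // activation_scale=1.0, activation_alpha=0.0, activation_beta=0.0.
  post_ops po;
  po.append_eltwise(/*scale=*/1.0f, algorithm::eltwise_relu,
                    /*alpha=*/0.0f, /*beta=*/0.0f);
  primitive_attr attr;
  attr.set_post_ops(po);

  // Binary (add) primitive descriptor with the eltwise post-op attached,
  // analogous to CreateAttributes() + AcquireForwardPrimitiveDescriptor().
  auto bin_d = binary::desc(algorithm::binary_add, md, md, md);
  auto bin_pd = binary::primitive_desc(bin_d, attr, eng);

  std::vector<float> a(1 * 3 * 100 * 100, 1.0f);
  std::vector<float> b(a.size(), -2.0f);
  std::vector<float> out(a.size(), 0.0f);
  memory a_m(md, eng, a.data());
  memory b_m(md, eng, b.data());
  memory out_m(md, eng, out.data());

  // Single primitive computes relu(a + b); no intermediate tensor is
  // materialized, which is exactly what the IR fusion makes possible.
  binary(bin_pd).execute(strm, {{DNNL_ARG_SRC_0, a_m},
                                {DNNL_ARG_SRC_1, b_m},
                                {DNNL_ARG_DST, out_m}});
  strm.wait();
  // out[i] == relu(1.0f + -2.0f) == 0.0f for every element.
  return 0;
}

Design note: because the activation becomes a post-op of the elementwise primitive, the pass only has to rewrite attributes and relink the activation's output variable to the elementwise op (IR_OP_VAR_LINK above); the intermediate elementwise_out variable and the standalone activation op are then safe to remove from the graph.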