From a346c4dcb83f42b0c8d653c086854306e7dbfb17 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 11 Nov 2021 07:33:28 +0100 Subject: [PATCH] Added softplus + activation oneDNN fuse pass (#36657) * added softplus + activation fuse plass * minor change * implemented reviewer suggestion * minor fix * minor fix * added scale_out parameter * minor fix * fix for iScan CI * conditionally disabled logs * refactored pass builder --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 21 +++ .../framework/ir/graph_pattern_detector.h | 18 +++ .../softplus_activation_mkldnn_fuse_pass.cc | 133 +++++++++++++++++ .../softplus_activation_mkldnn_fuse_pass.h | 44 ++++++ ...plus_activation_mkldnn_fuse_pass_tester.cc | 101 +++++++++++++ .../inference/api/paddle_pass_builder.cc | 3 +- paddle/fluid/operators/activation_op.cc | 20 +++ .../operators/mkldnn/softplus_mkldnn_op.h | 44 +++++- ...st_mkldnn_softplus_activation_fuse_pass.py | 136 ++++++++++++++++++ 10 files changed, 516 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 80ae0f04daa..49bc8c908c9 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -117,6 +117,7 @@ if(WITH_MKLDNN) pass_library(cpu_bfloat16_pass inference DIR mkldnn) pass_library(fc_mkldnn_pass inference DIR mkldnn) pass_library(interpolate_mkldnn_pass inference DIR mkldnn) + pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_pass inference DIR mkldnn) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6830a1f85e0..dd0ffe8b9fd 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1061,6 +1061,27 @@ PDNode *patterns::FCActOneDNN::operator()(const std::string &act_type) { return act_out; } +PDNode *patterns::SoftplusActivation::operator()(std::string activation_type) { + // Create Operators + auto *softplus_op = + pattern->NewNode(softplus_repr())->assert_is_op("softplus"); + auto *activation_op = + pattern->NewNode(activation_repr())->assert_is_op(activation_type); + // intermediate variable, will be removed in the IR after fuse. 
+ auto *softplus_out = pattern->NewNode(softplus_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("softplus") + ->assert_is_op_input(activation_type); + // output + auto *activation_out = pattern->NewNode(activation_out_repr()) + ->AsOutput() + ->assert_is_op_output(activation_type); + + softplus_op->LinksTo({softplus_out}); + activation_op->LinksFrom({softplus_out}).LinksTo({activation_out}); + return activation_out; +} + PDNode *patterns::Embedding::operator()(PDNode *x) { x->assert_is_op_input("lookup_table", "Ids"); auto *lookup_table_op = diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 6657ab5a6a5..d7bfdc57d1c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -577,6 +577,24 @@ struct FCActOneDNN : public PatternBase { PATTERN_DECL_NODE(act_out); }; +// Fuse softplus with activation +// ops: softplus + activation +// nodes: +// softplus, softplus_out, +// activation, activation_out +struct SoftplusActivation : public PatternBase { + SoftplusActivation(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "softplus_activation") {} + + PDNode* operator()(std::string activation_type); + + // declare operator node's name + PATTERN_DECL_NODE(softplus); + PATTERN_DECL_NODE(activation); + PATTERN_DECL_NODE(softplus_out); + PATTERN_DECL_NODE(activation_out); +}; + // Embedding struct Embedding : public PatternBase { Embedding(PDPattern* pattern, const std::string& name_scope) diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc new file mode 100644 index 00000000000..82d642264c2 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" + +namespace paddle { +namespace framework { +namespace ir { + +using string::PrettyLogDetail; + +void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const { + std::vector act_types = { + "relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt", + "abs", "clip", "gelu", "relu6", "sigmoid"}; + + for (const auto &act_type : act_types) { + std::unordered_map attr_map; + + if (act_type == "swish") + attr_map.emplace("beta", "fuse_activation_alpha"); + else if (act_type == "relu6") + attr_map.emplace("threshold", "fuse_activation_alpha"); + else if (act_type == "clip") { + attr_map.emplace("min", "fuse_activation_alpha"); + attr_map.emplace("max", "fuse_activation_beta"); + } else { + attr_map.emplace("alpha", "fuse_activation_alpha"); + attr_map.emplace("beta", "fuse_activation_beta"); + } + FuseSoftplusActivation(graph, act_type, attr_map); + } +} + +void SoftplusActivationOneDNNPass::FuseSoftplusActivation( + Graph *graph, const std::string &fuse_activation_type, + const std::unordered_map &attr_map) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + FusePassBase::Init("softplus_activation", graph); + + GraphPatternDetector gpd; + patterns::SoftplusActivation softplus_activation_pattern( + gpd.mutable_pattern(), "softplus_activation"); + softplus_activation_pattern(fuse_activation_type); + + int found_softplus_activation_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Fuse softplus with activation op."; + GET_IR_NODE_FROM_SUBGRAPH(softplus_out, softplus_out, + softplus_activation_pattern); + GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, + softplus_activation_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softplus, softplus, softplus_activation_pattern); + GET_IR_NODE_FROM_SUBGRAPH(activation, activation, + softplus_activation_pattern); + + auto *softplus_op = softplus->Op(); + + if (softplus_op->HasAttr("use_mkldnn")) { + PADDLE_ENFORCE_EQ( + BOOST_GET_CONST(bool, softplus_op->GetAttr("use_mkldnn")), true, + platform::errors::PreconditionNotMet("The softplus + activation " + "fusion may happen only when " + "oneDNN library is used.")); + } + + auto *activation_op = activation->Op(); + for (const auto &attr : attr_map) { + if (activation_op->HasAttr(attr.first)) { + softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first)); + } + } + + if (fuse_activation_type == "gelu" && + activation_op->HasAttr("approximate") && + BOOST_GET_CONST(bool, activation_op->GetAttr("approximate"))) + softplus_op->SetAttr("fuse_activation_type", std::string("gelu_tanh")); + else + softplus_op->SetAttr("fuse_activation_type", fuse_activation_type); + + softplus_op->SetAttr("use_mkldnn", true); + + softplus_op->SetOutput("Out", {activation_out->Name()}); + + IR_OP_VAR_LINK(softplus, activation_out); + GraphSafeRemoveNodes(g, {activation, softplus_out}); + found_softplus_activation_count++; + }; + + gpd(graph, handler); + AddStatis(found_softplus_activation_count); + if (!Has("disable_logs") || !Get("disable_logs")) + PrettyLogDetail("--- fused %d softplus with %s activation", + found_softplus_activation_count, fuse_activation_type); +} +} // namespace ir +} // namespace framework +} 
// namespace paddle + +REGISTER_PASS(softplus_activation_mkldnn_fuse_pass, + paddle::framework::ir::SoftplusActivationOneDNNPass); +REGISTER_PASS_CAPABILITY(softplus_activation_mkldnn_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("softplus", 1) + .EQ("relu", 0) + .EQ("tanh", 0) + .LE("leaky_relu", 1) + .EQ("swish", 0) + .EQ("hard_swish", 0) + .EQ("sqrt", 0) + .EQ("abs", 0) + .LE("relu6", 1) + .LE("clip", 1) + .EQ("gelu", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h new file mode 100644 index 00000000000..b025ab75a11 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * \brief Fuse the softplus and another activation operators into + * softplus with another activation as post-op. + */ +class SoftplusActivationOneDNNPass : public FusePassBase { + public: + virtual ~SoftplusActivationOneDNNPass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + + void FuseSoftplusActivation( + ir::Graph *graph, const std::string &fuse_activation_type, + const std::unordered_map &attr_map) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc new file mode 100644 index 00000000000..003a39f37d4 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass_tester.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h" + +#include +#include +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +void MainTest(const std::string& activation_type) { + auto prog = + test::BuildProgramDesc({"softplus_x", "softplus_out", "activation_out"}); + test::CreateOp(&prog, "softplus", {{"X", "softplus_x"}}, + {{"Out", "softplus_out"}}); + test::CreateOp(&prog, activation_type, {{"X", "softplus_out"}}, + {{"Out", "activation_out"}}, false); + + Graph graph(prog); + constexpr int removed_nodes_count = 2; + + EXPECT_TRUE(test::RunPassAndAssert( + &graph, "softplus_activation_mkldnn_fuse_pass", "softplus_x", + "activation_out", removed_nodes_count)); + EXPECT_TRUE( + test::AssertOpsCount(graph, {{"softplus", 1}, {activation_type, 0}})); + + for (const auto* node : graph.Nodes()) { + if (node->IsOp() && node->Op()->Type() == "softplus") { + const auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn"))); + ASSERT_TRUE(op->HasAttr("fuse_activation_type")); + auto activation_type = + BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation_type")); + EXPECT_EQ(activation_type.compare(activation_type), 0); + } + } +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithTanh) { + MainTest("tanh") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithRelu) { + MainTest("relu") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithLeakyRelu) { + MainTest("leaky_relu") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithSwish) { + MainTest("swish") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithHardswish) { + MainTest("hardswish") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithSqrt) { + MainTest("sqrt") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithAbs) { MainTest("abs") } + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithClip) { + MainTest("clip") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithGelu) { + MainTest("gelu") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithRelu6) { + MainTest("relu6") +} + +TEST(FuseSoftplusActivationOneDNNPass, FuseSoftplusWithSigmoid) { + MainTest("sigmoid") +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(softplus_activation_mkldnn_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 7d867b59e7d..334a70d3e06 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -257,7 +257,8 @@ void CpuPassStrategy::EnableMKLDNN() { // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", // "fc_act_mkldnn_fuse_pass", - "batch_norm_act_fuse_pass", + "batch_norm_act_fuse_pass", // + "softplus_activation_mkldnn_fuse_pass", // // TODO(intel): Please fix the bug on windows. 
// https://github.com/PaddlePaddle/Paddle/issues/29710 // "mkldnn_inplace_pass", // This pass should be activated after diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 5e5cd0ea1c5..624ada0dca7 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -447,6 +447,26 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Only used in cudnn kernel, need install cudnn.") .SetDefault(false) .AsExtra(); + AddAttr( + "fuse_activation_type", + "Fused activation type used in softplus OneDNN kernel.") + .SetDefault("") + .AsExtra(); + AddAttr( + "fuse_activation_alpha", + "Fused activation alpha parameter type used in softplus OneDNN kernel.") + .SetDefault(0.0f) + .AsExtra(); + AddAttr( + "fuse_activation_beta", + "Fused activation beta parameter type used in softplus OneDNN kernel.") + .SetDefault(0.0f) + .AsExtra(); + AddAttr( + "fuse_activation_scale", + "Fused activation scale parameter type used in softplus OneDNN kernel.") + .SetDefault(1.0f) + .AsExtra(); AddComment(R"DOC( :strong:`Softplus Activation Operator` diff --git a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h index fdb2c534e03..60ea5136905 100644 --- a/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h +++ b/paddle/fluid/operators/mkldnn/softplus_mkldnn_op.h @@ -23,9 +23,10 @@ template class SoftplusMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT { public: - SoftplusMKLDNNHandler(const Tensor* x, const float beta, - const mkldnn::engine engine, platform::Place cpu_place) - : platform::MKLDNNHandlerNoCachingT(engine, cpu_place) { + SoftplusMKLDNNHandler(const framework::ExecutionContext& ctx, const Tensor* x, + const float beta, const mkldnn::engine engine) + : platform::MKLDNNHandlerNoCachingT(engine, + ctx.GetPlace()) { auto x_tz = framework::vectorize(x->dims()); auto x_md = dnnl::memory::desc(x_tz, platform::MKLDNNGetDataType(), x->format()); @@ -42,6 +43,8 @@ class SoftplusMKLDNNHandler 1.0f / beta, 0.0f); } + AppendFusedActivationIfExists(ctx, post_ops); + dnnl::primitive_attr attrs; attrs.set_post_ops(post_ops); @@ -53,8 +56,41 @@ class SoftplusMKLDNNHandler return this->AcquireMemoryFromPrimitive( this->fwd_pd_->src1_desc(), platform::to_void_cast(beta)); } + + private: + void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx, + dnnl::post_ops& post_ops) { + const auto& fused_activation_type = + algo_map.find(ctx.Attr("fuse_activation_type")); + + if (fused_activation_type != algo_map.end()) { + auto scale_out = + ctx.Attr("fuse_activation_scale"); // for future int8 support + post_ops.append_eltwise(scale_out, fused_activation_type->second, + ctx.Attr("fuse_activation_alpha"), + ctx.Attr("fuse_activation_beta")); + } + } + + static const std::unordered_map algo_map; }; +template +const std::unordered_map + SoftplusMKLDNNHandler::algo_map = { + {"relu", dnnl::algorithm::eltwise_relu}, + {"tanh", dnnl::algorithm::eltwise_tanh}, + {"leaky_relu", dnnl::algorithm::eltwise_relu}, + {"swish", dnnl::algorithm::eltwise_swish}, + {"hardswish", dnnl::algorithm::eltwise_hardswish}, + {"sqrt", dnnl::algorithm::eltwise_sqrt}, + {"abs", dnnl::algorithm::eltwise_abs}, + {"clip", dnnl::algorithm::eltwise_clip}, + {"gelu", dnnl::algorithm::eltwise_gelu_erf}, + {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, + {"relu6", dnnl::algorithm::eltwise_bounded_relu}, + {"sigmoid", dnnl::algorithm::eltwise_logistic}}; + template 
void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) { const auto& dev_ctx = @@ -68,7 +104,7 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) { const float beta = ctx.Attr("beta"); - SoftplusMKLDNNHandler handler(x, beta, mkldnn_engine, ctx.GetPlace()); + SoftplusMKLDNNHandler handler(ctx, x, beta, mkldnn_engine); auto src_memory_p = handler.AcquireSrcMemory(x); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py new file mode 100644 index 00000000000..83c095baeff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_softplus_activation_fuse_pass.py @@ -0,0 +1,136 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle +import paddle.fluid as fluid +from paddle.fluid.core import PassVersionChecker + + +class SoftplusActivationReluOneDNNFusePassTest(InferencePassTest): + fuse_activation_alpha = None + fuse_activation_beta = None + pass_name = 'softplus_activation_mkldnn_fuse_pass' + + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 100, 100], dtype="float32") + softplus_out = fluid.layers.softplus(data) + if self.fuse_activation_beta is not None: + activation_out = self.fuse_activation( + softplus_out, self.fuse_activation_alpha, + self.fuse_activation_beta) + elif self.fuse_activation_alpha is not None: + activation_out = self.fuse_activation( + softplus_out, self.fuse_activation_alpha) + else: + activation_out = self.fuse_activation(softplus_out) + + self.feeds = { + "data": np.random.random((1, 3, 100, 100)).astype("float32"), + } + self.fetch_list = [activation_out] + self.enable_mkldnn = True + + def set_params(self): + self.fuse_activation = fluid.layers.relu + + def test_check_output(self): + use_gpu = False + self.check_output_with_option(use_gpu) + + def test_pass_compatible(self): + self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name)) + + +class SoftplusActivationTanhOneDNNFusePassTest( + SoftplusActivationReluOneDNNFusePassTest): + def set_params(self): + self.fuse_activation = fluid.layers.tanh + + +class SoftplusActivationLeakyReluOneDNNFusePassTest( + SoftplusActivationReluOneDNNFusePassTest): + def set_params(self): + self.fuse_activation = fluid.layers.leaky_relu + self.fuse_activation_alpha = 0.3 + + +class SoftplusActivationSwishOneDNNFusePassTest( + SoftplusActivationReluOneDNNFusePassTest): + def set_params(self): + self.fuse_activation = fluid.layers.swish + self.fuse_activation_alpha = 3 + + +class SoftplusActivationHardSwishOneDNNFusePassTest( + SoftplusActivationReluOneDNNFusePassTest): + def set_params(self): + 
+        self.fuse_activation = fluid.layers.hard_swish
+
+
+class SoftplusActivationSqrtOneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.sqrt
+
+
+class SoftplusActivationAbsOneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.abs
+
+
+class SoftplusActivationClipOneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.clip
+        self.fuse_activation_alpha = 1.1
+        self.fuse_activation_beta = 5.2
+
+
+class SoftplusActivationGeluErfOneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.gelu
+
+
+class SoftplusActivationGeluTanhOneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.gelu
+        self.fuse_activation_alpha = True  # simulated "Approximate" attr
+
+
+class SoftplusActivationRelu6OneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.relu6
+
+
+class SoftplusActivationSigmoidOneDNNFusePassTest(
+        SoftplusActivationReluOneDNNFusePassTest):
+    def set_params(self):
+        self.fuse_activation = fluid.layers.sigmoid
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
-- 
GitLab
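
Reviewer note (not part of the patch): a minimal NumPy sketch of the numerics the fused kernel is expected to reproduce. Softplus is evaluated first and the fused post-op activation is applied to its output; the helper names below are illustrative only, and the threshold-based linear branch Paddle's softplus uses for large beta * x is omitted for brevity.

import numpy as np


def softplus(x, beta=1.0):
    # Paddle's softplus: log(1 + exp(beta * x)) / beta (threshold branch omitted).
    return np.log1p(np.exp(beta * x)) / beta


def fused_softplus_relu(x, beta=1.0):
    # Expected result of the fused "softplus + relu" op:
    # softplus first, then the post-op activation applied to its output.
    return np.maximum(softplus(x, beta), 0.0)


x = np.random.random((1, 3, 100, 100)).astype("float32")
expected = fused_softplus_relu(x)  # reference to compare against the fused op

Activations that carry parameters (e.g. leaky_relu, clip, relu6) apply their alpha/beta values in the same position; those are the values the pass copies from the activation op into the softplus op's fuse_activation_alpha/fuse_activation_beta attributes.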