diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
index cdb0f70a56667d3f50801db138a1d40563981f7f..f4ac65a9ab1993d990f52191515d0cd5d4b6cd44 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
 
 #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
 
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -26,20 +25,20 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "gelu", "tanh", "sigmoid", "mish", "hard_swish"};
+  auto act_types = paddle::platform::GetSupportedActivations();
 
-  for (std::string act_type : act_types) FuseFCAct(graph, act_type);
+  for (auto act_type : act_types) FuseFCAct(graph, act_type);
 }
 
 void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("fc_act", graph);
+  FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  patterns::OperatorActivation fc_act_pattern(
+      gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
   fc_act_pattern("fc", act_type);
 
   int found_fc_act_count = 0;
@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                         "is used."));
     }
 
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto &attr : attr_map) {
+      if (act_op->HasAttr(attr.first)) {
+        fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
+      }
+    }
+
     if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-      bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
-      std::string type = approximate ? "_tanh" : "_erf";
-      fc_op->SetAttr("activation_type", act_type + type);
+      std::string gelu_act_type =
+          PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
+                                                                 : "gelu_erf";
+      fc_op->SetAttr("fuse_activation", gelu_act_type);
     } else {
-      fc_op->SetAttr("activation_type", act_type);
+      fc_op->SetAttr("fuse_activation", act_type);
     }
-    fc_op->SetAttr("use_mkldnn", true);
+    fc_op->SetAttr("use_mkldnn", true);
 
     fc_op->SetOutput("Out", {act_out->Name()});
 
     IR_OP_VAR_LINK(fc, act_out);
@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
   gpd(graph, handler);
   AddStatis(found_fc_act_count);
 
-  if (!Has("disable_logs") || !Get<bool>("disable_logs"))
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_fc_act_count > 0)
     PrettyLogDetail(
         "---    fused %d fc with %s activation", found_fc_act_count, act_type);
 }
@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("fc", 0)
-            .LE("gelu", 0)
-            .LE("sigmoid", 0)
-            .LE("mish", 1)
+            .EQ("abs", 0)
+            .LE("clip", 1)
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
             .LE("hard_swish", 0)
-            .LE("tanh", 0));
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
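Note: `GetSupportedActivations()` and `GetAttributeMap()` come from `paddle/fluid/platform/mkldnn_reuse.h`, which this patch includes but does not modify, so their definitions are not shown here. A minimal sketch of the translation the attribute loop above relies on; the function name and map contents below are illustrative assumptions, not code from this patch:

```cpp
// Hypothetical sketch of GetAttributeMap(); the real helper lives in
// paddle/fluid/platform/mkldnn_reuse.h. It remaps an activation op's own
// attribute names to the generic fused names the oneDNN FC kernel reads.
#include <string>
#include <unordered_map>

std::unordered_map<std::string, std::string> GetAttributeMapSketch(
    const std::string& act_type) {
  if (act_type == "swish") return {{"beta", "fuse_alpha"}};
  if (act_type == "relu6") return {{"threshold", "fuse_alpha"}};
  if (act_type == "leaky_relu") return {{"alpha", "fuse_alpha"}};
  if (act_type == "clip") return {{"min", "fuse_alpha"}, {"max", "fuse_beta"}};
  return {};  // most activations carry no extra attributes
}
```

With a map of this shape, the loop in the handler copies, e.g., `clip`'s `min`/`max` onto the `fc` op as `fuse_alpha`/`fuse_beta` before the standalone activation node is erased from the graph.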
"gelu_tanh" + : "gelu_erf"; + fc_op->SetAttr("fuse_activation", gelu_act_type); } else { - fc_op->SetAttr("activation_type", act_type); + fc_op->SetAttr("fuse_activation", act_type); } - fc_op->SetAttr("use_mkldnn", true); + fc_op->SetAttr("use_mkldnn", true); fc_op->SetOutput("Out", {act_out->Name()}); IR_OP_VAR_LINK(fc, act_out); @@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, gpd(graph, handler); AddStatis(found_fc_act_count); - if (!Has("disable_logs") || !Get("disable_logs")) + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_fc_act_count > 0) PrettyLogDetail( "--- fused %d fc with %s activation", found_fc_act_count, act_type); } @@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("fc", 0) - .LE("gelu", 0) - .LE("sigmoid", 0) - .LE("mish", 1) + .EQ("abs", 0) + .LE("clip", 1) + .EQ("gelu", 0) + .EQ("hard_sigmoid", 0) .LE("hard_swish", 0) - .LE("tanh", 0)); + .LE("leaky_relu", 1) + .LE("mish", 1) + .EQ("relu", 0) + .EQ("relu6", 0) + .EQ("sigmoid", 0) + .EQ("sqrt", 0) + .EQ("swish", 0) + .EQ("tanh", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h index 23f4296b98bcabab17c896a7ea0c80f72e358e06..7e4032d4a135292baa600146718627a486e5149e 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h @@ -23,21 +23,14 @@ namespace paddle { namespace framework { namespace ir { -/* - * \brief Fuse the FC and activation operators into single OneDNN's - * FC with post-op. - * - * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported - * as an activation function. 
- */ class FuseFCActOneDNNPass : public FusePassBase { public: virtual ~FuseFCActOneDNNPass() {} protected: - void ApplyImpl(ir::Graph *graph) const override; + void ApplyImpl(Graph *graph) const override; - void FuseFCAct(ir::Graph *graph, const std::string &act_types) const; + void FuseFCAct(Graph *graph, const std::string &act_types) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc index 38f253703ceeec75f45ac89639e6483f65f472b5..643d43913800500dad6130628741de0966988d00 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc @@ -34,12 +34,12 @@ TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}, false); - test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); // No fusion in this attribute configuration @@ -58,12 +58,12 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); - auto* act_op = test::CreateOp( - &prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto* act_op = + test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); act_op->SetAttr("approximate", true); Graph graph(prog); @@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("gelu_tanh"), 0); } } @@ -93,12 +93,12 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); - auto* act_op = test::CreateOp( - &prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + auto* act_op = + test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); act_op->SetAttr("approximate", false); Graph graph(prog); @@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("gelu_erf"), 0); } } @@ -128,11 +128,11 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); - test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; @@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, 
op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("gelu"), 0); } } @@ -161,11 +161,11 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); - test::CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "tanh", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; @@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("tanh"), 0); } } @@ -194,12 +194,11 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); - test::CreateOp( - &prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "sigmoid", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; @@ -213,9 +212,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("sigmoid"), 0); } } @@ -228,11 +227,11 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); - test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + test::CreateOp(&prog, "mish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; @@ -246,9 +245,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("mish"), 0); } } @@ -261,12 +260,12 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) { "fc", { {"Input", "x"}, - {"Weights", "weights"}, + {"W", "weights"}, {"Bias", "bias"}, }, {{"Out", "fc_y"}}); test::CreateOp( - &prog, "hard_swish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false); + &prog, "hard_swish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false); Graph graph(prog); constexpr int removed_nodes_count = 2; @@ -280,9 +279,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) { const auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); - 
ASSERT_TRUE(op->HasAttr("activation_type")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); auto act_type = - PADDLE_GET_CONST(std::string, op->GetAttr("activation_type")); + PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation")); EXPECT_EQ(act_type.compare("hard_swish"), 0); } } diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index a9b17634531b0b92482807acec0e85c757cb78a1..2512a85357934e665fb02153c739cf9ea5221003 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -242,7 +242,7 @@ bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) { if (input_compats_.find(input_desc.first) == input_compats_.end()) { if (!input_desc.second.empty()) { LOG(WARNING) << "The Input (" << input_desc.first << ") of Operator (" - << op_name_ << ") not reigistered in OpCompat!"; + << op_name_ << ") not registered in OpCompat!"; return false; } } @@ -269,7 +269,7 @@ bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) { if (output_compats_.find(output_desc.first) == output_compats_.end()) { if (!output_desc.second.empty()) { LOG(WARNING) << "The Output (" << output_desc.first << ") of Operator (" - << op_name_ << ") not reigistered in OpCompat!"; + << op_name_ << ") not registered in OpCompat!"; return false; } } diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 5dfb81350b5c1f4c68074a4bf4e4a2e1be9facd9..12d2bdef79148407c57708804a0f25a764e5a124 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -87,8 +87,7 @@ class FCMKLDNNHandler dnnl::memory::format_tag::a); } - dnnl::primitive_attr attrs; - HandlePostOps(ctx, &attrs); + const auto attrs = CreateFCAttrs(ctx); this->AcquireForwardPrimitiveDescriptor(attrs, prop_kind::forward_inference, @@ -99,44 +98,33 @@ class FCMKLDNNHandler } private: - void HandlePostOps(const paddle::framework::ExecutionContext& ctx, - dnnl::primitive_attr* attrs) { - static std::unordered_map algo_map = { - {"relu", dnnl::algorithm::eltwise_relu}, - {"gelu", dnnl::algorithm::eltwise_gelu}, - {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh}, - {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf}, - {"tanh", dnnl::algorithm::eltwise_tanh}, - {"sigmoid", dnnl::algorithm::eltwise_logistic}, - {"hard_swish", dnnl::algorithm::eltwise_hardswish}, - {"mish", dnnl::algorithm::eltwise_mish}}; + dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) { + dnnl::primitive_attr attributes; + dnnl::post_ops post_operations; std::vector output_shift_scale; float scale = 1.0f; if (IsInt8()) { std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx); int mask = CreateMask(1, output_shift_scale.size() > 1); - attrs->set_output_scales(mask, output_shift_scale); + attributes.set_output_scales(mask, output_shift_scale); } - dnnl::post_ops post_ops; - - constexpr float sum_scale = 1.0f; + float sum_scale = 1.0f; if (ctx.HasAttr("fuse_residual_connection") && ctx.Attr("fuse_residual_connection")) { - post_ops.append_sum(sum_scale); + post_operations.append_sum(sum_scale); } - std::string activation_type = ctx.Attr("activation_type"); - - if (activation_type.empty() == false) { - constexpr float alpha = 0.0f; - constexpr float beta = 0.0f; - - post_ops.append_eltwise(scale, algo_map[activation_type], alpha, beta); + // ReLU from "fc_fuse_pass" + if (ctx.Attr("activation_type") == 
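The kernel-side change above drops the private string-to-algorithm map: the ReLU set by the older `fc_fuse_pass` (via `activation_type`) is appended explicitly, and every other fused activation is delegated to `platform::AppendActivation`. For readers unfamiliar with the oneDNN post-op mechanism, here is a self-contained sketch against the oneDNN 2.x API that this code targets (the helper name is mine, not Paddle's):

```cpp
#include "dnnl.hpp"

// Minimal sketch of the oneDNN 2.x post-op mechanism used in CreateFCAttrs.
// The eltwise post-op runs on the FC result inside the primitive, so no
// intermediate tensor is materialized for the activation.
dnnl::primitive_attr MakeAttrsWithEltwise(dnnl::algorithm algo,
                                          float alpha,
                                          float beta,
                                          float scale = 1.0f) {
  dnnl::post_ops ops;
  // alpha/beta parameterize the eltwise, e.g. leaky_relu's negative slope
  // or clip's lower/upper bounds.
  ops.append_eltwise(scale, algo, alpha, beta);
  dnnl::primitive_attr attrs;
  attrs.set_post_ops(ops);
  return attrs;
}
```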
"relu") { + post_operations.append_eltwise( + scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f); } + platform::AppendActivation(ctx, post_operations, scale); - attrs->set_post_ops(post_ops); + attributes.set_post_ops(post_operations); + return attributes; } // Compute the bias scales so that its values correspond to the diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index fbe6fbf26f81114f08825ea5bcf8f87cb742f822..731ec7306fba5e3bb66b0692f7f957deb875c3c0 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -226,7 +226,8 @@ if(WITH_GPU AND TENSORRT_FOUND) set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_conv_mish_fuse_pass PROPERTIES TIMEOUT 300) - set_tests_properties(test_mkldnn_fc_mish_fuse_pass PROPERTIES TIMEOUT 300) + set_tests_properties(test_onednn_fc_activation_fuse_pass PROPERTIES TIMEOUT + 300) set_tests_properties(test_mkldnn_fc_elementwise_add_fuse_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_mkldnn_conv_affine_channel_fuse_pass diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_mish_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_mish_fuse_pass.py deleted file mode 100644 index dd9321b6a74be242ae1cf003c71efa34bbc197d6..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_fc_mish_fuse_pass.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from auto_scan_test import PassAutoScanTest -from program_config import TensorConfig, ProgramConfig -import numpy as np -import unittest -import hypothesis.strategies as st - - -class TestFCMishMkldnnFusePass(PassAutoScanTest): - - def sample_program_config(self, draw): - x_shape = draw( - st.lists(st.integers(min_value=1, max_value=128), - min_size=2, - max_size=3)) - in_num_col_dims = len(x_shape) - 1 - w_shape = draw( - st.lists(st.integers(min_value=1, max_value=128), - min_size=2, - max_size=2)) - w_shape[0] = int(np.prod(x_shape[in_num_col_dims:])) - fc_bias_shape = [w_shape[1]] - - ops_config = [{ - "op_type": "fc", - "op_inputs": { - "Input": ["fc_x"], - "W": ["fc_w"], - "Bias": ["fc_bias"] - }, - "op_outputs": { - "Out": ["fc_out"] - }, - "op_attrs": { - "activation_type": "", - "padding_weights": False, - "in_num_col_dims": in_num_col_dims, - "use_mkldnn": True - } - }, { - "op_type": "mish", - "op_inputs": { - "X": ["fc_out"] - }, - "op_outputs": { - "Out": ["mish_output"] - }, - "op_attrs": {}, - }] - - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig(ops=ops, - weights={ - "fc_w": - TensorConfig(shape=w_shape), - "fc_bias": - TensorConfig(shape=fc_bias_shape), - }, - inputs={ - "fc_x": TensorConfig(shape=x_shape), - }, - outputs=["mish_output"]) - return program_config - - def sample_predictor_configs(self, program_config): - config = self.create_inference_config( - use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"]) - yield config, ["fc"], (1e-5, 1e-5) - - def test(self): - self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_activation_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_activation_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..4e798a0ed57d5338ab8e6c0b7a3c7f622d6ff57d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_onednn_fc_activation_fuse_pass.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from auto_scan_test import PassAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +from functools import partial +import unittest +import hypothesis.strategies as st + + +class TestFCActivationOneDNNFusePass(PassAutoScanTest): + + def sample_program_config(self, draw): + fc_in = draw(st.sampled_from([32, 64])) + fc_wei = draw(st.sampled_from([64])) + activation_type = draw( + st.sampled_from([ + 'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish', + 'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid', + 'leaky_relu' + ])) + + def generate_input(shape): + return np.random.random(shape).astype(np.float32) + + fc_op = OpConfig(type="fc", + inputs={ + "Input": ["fc_input"], + "W": ["fc_weight"], + "Bias": ["fc_bias"] + }, + outputs={"Out": ["fc_output"]}, + attrs={ + "use_mkldnn": True, + "padding_weights": False, + "in_num_col_dims": 1, + }) + + if activation_type == "clip": + activation_op = OpConfig( + activation_type, + inputs={"X": ["fc_output"]}, + outputs={"Out": ["activation_output"]}, + min=draw(st.floats(min_value=0.1, max_value=0.49)), + max=draw(st.floats(min_value=0.5, max_value=1.0))) + elif activation_type == "gelu": + activation_op = OpConfig(activation_type, + inputs={"X": ["fc_output"]}, + outputs={"Out": ["activation_output"]}, + approximate=draw(st.booleans())) + elif activation_type == "leaky_relu": + activation_op = OpConfig(activation_type, + inputs={"X": ["fc_output"]}, + outputs={"Out": ["activation_output"]}, + alpha=draw( + st.floats(min_value=0.1, + max_value=1.0))) + elif activation_type == "relu6": + activation_op = OpConfig(activation_type, + inputs={"X": ["fc_output"]}, + outputs={"Out": ["activation_output"]}, + threshold=6) + elif activation_type == "swish": + activation_op = OpConfig(activation_type, + inputs={"X": ["fc_output"]}, + outputs={"Out": ["activation_output"]}, + beta=draw( + st.floats(min_value=0.1, + max_value=10.0))) + else: + activation_op = OpConfig(activation_type, + inputs={"X": ["fc_output"]}, + outputs={"Out": ["activation_output"]}) + + model_net = [fc_op, activation_op] + + program_config = ProgramConfig( + ops=model_net, + weights={ + "fc_weight": + TensorConfig( + data_gen=partial(generate_input, [fc_wei, fc_wei])), + "fc_bias": + TensorConfig(data_gen=partial(generate_input, [fc_wei])), + }, + inputs={ + "fc_input": + TensorConfig(data_gen=partial(generate_input, [fc_in, fc_wei])) + }, + outputs=["activation_output"]) + + return program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config( + use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"]) + yield config, ["fc"], (1e-5, 1e-5) + + def test(self): + self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main()
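For completeness: `platform::AppendActivation`, also defined in `mkldnn_reuse.h` and therefore outside this diff, is the piece that finally consumes the `fuse_activation`/`fuse_alpha`/`fuse_beta` attributes the pass sets. A hedged sketch of its plausible shape; the map contents are assumptions inferred from the activations listed in the new test, not code from this patch:

```cpp
#include <string>
#include <unordered_map>

#include "dnnl.hpp"

// Hypothetical mirror of platform::AppendActivation (mkldnn_reuse.h).
// It resolves the fused attribute strings into a oneDNN eltwise post-op.
void AppendActivationSketch(dnnl::post_ops& ops,
                            const std::string& fuse_activation,
                            float fuse_alpha,
                            float fuse_beta,
                            float scale = 1.0f) {
  if (fuse_activation.empty()) return;
  static const std::unordered_map<std::string, dnnl::algorithm> kAlgoMap = {
      {"relu", dnnl::algorithm::eltwise_relu},
      {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
      {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
      {"tanh", dnnl::algorithm::eltwise_tanh},
      {"sigmoid", dnnl::algorithm::eltwise_logistic},
      {"mish", dnnl::algorithm::eltwise_mish},
      {"hard_swish", dnnl::algorithm::eltwise_hardswish},
      {"swish", dnnl::algorithm::eltwise_swish},
      {"sqrt", dnnl::algorithm::eltwise_sqrt},
      {"abs", dnnl::algorithm::eltwise_abs},
      {"clip", dnnl::algorithm::eltwise_clip},
      {"leaky_relu", dnnl::algorithm::eltwise_relu}};  // alpha = slope
  ops.append_eltwise(scale, kAlgoMap.at(fuse_activation), fuse_alpha,
                     fuse_beta);
}
```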