diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 5eb94dbbce435da99a6da4a22bb985b9adf69bd8..8130c91a31373455e90ce2e2bbc9a8a5a98774b8 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -127,6 +127,7 @@ pass_library(dense_multihead_matmul_to_sparse_pass inference) pass_library(delete_cast_op_pass inference) pass_library(delete_elementwise_mul_op_pass inference) pass_library(delete_repeated_ops_pass inference) +pass_library(sigmoid_elementmul_fuse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) diff --git a/paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.cc b/paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..6904e5604fb5c75454681178a35c505334d99f9b --- /dev/null +++ b/paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct SigmoidElementmulFusePattern : public PatternBase { + SigmoidElementmulFusePattern(PDPattern* pattern, + const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(sigmoid); + PATTERN_DECL_NODE(elementwise_mul); + // declare variable node's name + PATTERN_DECL_NODE(sigmoid_x); + PATTERN_DECL_NODE(sigmoid_out); + PATTERN_DECL_NODE(elemul_out); +}; + +SigmoidElementmulFusePattern::SigmoidElementmulFusePattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* sigmoid_x = pattern->NewNode(sigmoid_x_repr()) + ->assert_is_op_input("sigmoid", "X") + ->assert_var_not_persistable(); + + auto* sigmoid_op = pattern->NewNode(sigmoid_repr())->assert_is_op("sigmoid"); + + auto* sigmoid_out = pattern->NewNode(sigmoid_out_repr()) + ->assert_is_op_output("sigmoid", "Out") + ->assert_var_not_persistable(); + + auto* elemul_op = + pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul"); + + auto* elemul_out = pattern->NewNode(elemul_out_repr()) + ->assert_is_op_output("elementwise_mul", "Out") + ->assert_var_not_persistable(); + + sigmoid_op->LinksFrom({sigmoid_x}).LinksTo({sigmoid_out}); + elemul_op->LinksFrom({sigmoid_x, sigmoid_out}).LinksTo({elemul_out}); +} + +} // namespace patterns + +SigmoidElementmulFusePass::SigmoidElementmulFusePass() {} + +void SigmoidElementmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be 
null.")); + Init(name_scope_, graph); + + GraphPatternDetector gpd; + patterns::SigmoidElementmulFusePattern pattern(gpd.mutable_pattern(), + name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle SigmoidElementmulFusePass fuse"; +#define GET_IR_NODE(node_) GET_IR_NODE_FROM_SUBGRAPH(node_, node_, pattern) + GET_IR_NODE(sigmoid_x); + GET_IR_NODE(sigmoid); + GET_IR_NODE(sigmoid_out); + GET_IR_NODE(elementwise_mul); + GET_IR_NODE(elemul_out); +#undef GET_IR_NODE + auto* block = sigmoid->Op()->Block(); + std::string elemul_out_name = elemul_out->Name(); + + // Generate swish op + framework::OpDesc swish_op_desc(block); + swish_op_desc.SetType("swish"); + swish_op_desc.SetInput("X", {sigmoid_x->Name()}); + swish_op_desc.SetAttr("beta", 1.f); + swish_op_desc.SetOutput("Out", {elemul_out_name}); + + auto* swish = graph->CreateOpNode(&swish_op_desc); + IR_NODE_LINK_TO(sigmoid_x, swish); + IR_NODE_LINK_TO(swish, elemul_out); + + // delete useless node + std::unordered_set delete_nodes; + delete_nodes = {sigmoid, sigmoid_out, elementwise_mul}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(sigmoid_elementmul_fuse_pass, + paddle::framework::ir::SigmoidElementmulFusePass); + +REGISTER_PASS_CAPABILITY(sigmoid_elementmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "swish", 0)); diff --git a/paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.h b/paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6d116ad9c81100599c517d47c1fa65a8a9fe532c --- /dev/null +++ b/paddle/fluid/framework/ir/sigmoid_elementmul_fuse_pass.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2023 PaddlePaddle 
Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+/*
+Fuse sigmoid + elementwise_mul that share one input into a single swish op.
+
+Origin subgraph:
+         input
+         /   \
+        |     |
+        |  sigmoid
+        |     |
+        |     |
+    elementwise_mul
+          |
+          |
+         out
+
+Fused subgraph:
+         input
+           |
+           |
+         swish
+           |
+           |
+          out
+*/
+class SigmoidElementmulFusePass : public FusePassBase {
+ public:
+  SigmoidElementmulFusePass();
+  virtual ~SigmoidElementmulFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  const std::string name_scope_{"sigmoid_elementmul_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index bea0b82ecd4941d48c4ef4c52dfc6547a15c3797..f9123a111771f1dafd9b0221c37cb61a957bc07b 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -522,6 +522,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "delete_cast_op_pass",
       "stack_fuse_pass",
       "fused_multi_transformer_xpu_pass",
+      "sigmoid_elementmul_fuse_pass",
       "fc_xpu_fuse_pass",
       "conv2d_xpu_fuse_pass",
       "link_xpu_op_max_pass",
diff --git a/test/ir/inference/test_xpu_sigmoid_elementmul_fuse_pass.py b/test/ir/inference/test_xpu_sigmoid_elementmul_fuse_pass.py
new
file mode 100644 index 0000000000000000000000000000000000000000..e6a348b30a8c95bc3e293ab8ddf79853838f49ba --- /dev/null +++ b/test/ir/inference/test_xpu_sigmoid_elementmul_fuse_pass.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestSigmoidElementmulFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["swish"], (1e-3, 1e-3) + + def sample_program_config(self, draw): + # 1. 
sigmoid + sigmoid_x_shape = draw( + st.lists( + st.integers(min_value=1, max_value=4), min_size=2, max_size=4 + ) + ) + + sigmoid_op = OpConfig( + "sigmoid", + inputs={"X": ["sigmoid_x"]}, + outputs={"Out": ["sigmoid_out"]}, + trans_x=False, + trans_y=False, + ) + mul_op = OpConfig( + "elementwise_mul", + inputs={"X": ["sigmoid_x"], "Y": ["sigmoid_out"]}, + outputs={"Out": ["out"]}, + axis=-1, + ) + ops = [sigmoid_op, mul_op] + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "sigmoid_x": TensorConfig(shape=sigmoid_x_shape), + }, + outputs=ops[-1].outputs["Out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["sigmoid_elementmul_fuse_pass"], + ) + + +if __name__ == "__main__": + np.random.seed(200) + unittest.main()