From 70ebef810557efaa3275412de0fb019e43a05c4e Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Thu, 30 Mar 2023 13:54:29 +0800 Subject: [PATCH] [XPU] add delete_concat_op_pass (#52304) --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/delete_concat_op_pass.cc | 119 ++++++++++++++++++ paddle/fluid/framework/ir/pass.cc | 1 + .../inference/api/paddle_pass_builder.cc | 1 + .../test_xpu_delete_concat_op_pass.py | 70 +++++++++++ 5 files changed, 192 insertions(+) create mode 100644 paddle/fluid/framework/ir/delete_concat_op_pass.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_concat_op_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ebed80b6bc9..75bcdb18209 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -101,6 +101,7 @@ pass_library(delete_op_device_pass inference) pass_library(delete_weight_dequant_linear_op_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_dropout_op_pass inference) +pass_library(delete_concat_op_pass inference) pass_library(delete_c_identity_op_pass inference) pass_library(preln_residual_bias_fuse_pass inference) pass_library(delete_fill_constant_op_pass inference) diff --git a/paddle/fluid/framework/ir/delete_concat_op_pass.cc b/paddle/fluid/framework/ir/delete_concat_op_pass.cc new file mode 100644 index 00000000000..64a573d7a5a --- /dev/null +++ b/paddle/fluid/framework/ir/delete_concat_op_pass.cc @@ -0,0 +1,119 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+// Pattern: any_op -> any_op_out -> concat -> concat_out (concat has one "X").
+struct ConcatPattern : public PatternBase {
+  ConcatPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // operator nodes: the op producing concat's input, and the concat op itself
+  PATTERN_DECL_NODE(any_op);
+  PATTERN_DECL_NODE(concat);
+  // variable nodes: concat's single "X" input and concat's "Out" output
+  PATTERN_DECL_NODE(any_op_out);
+  PATTERN_DECL_NODE(concat_out);
+};
+// any_op_out needs 1 input and 1 output node; concat needs exactly one "X" name.
+ConcatPattern::ConcatPattern(PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* any_op = pattern->NewNode(any_op_repr())->assert_is_op();
+  auto* any_op_out = pattern->NewNode(any_op_out_repr())
+                         ->assert_is_op_input("concat", "X")
+                         ->assert_has_n_inputs(1)
+                         ->assert_has_n_outputs(1);
+  auto* concat = pattern->NewNode(concat_repr())
+                     ->assert_is_op("concat")
+                     ->assert_has_n_inputs(1)
+                     ->assert_more([](Node* node) {
+                       return node->Op()->Input("X").size() == 1;
+                     });
+  auto* concat_out =
+      pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
+  any_op->LinksTo({any_op_out});
+  concat->LinksFrom({any_op_out}).LinksTo({concat_out});
+}
+
+}
// namespace patterns
+
+/* Removes a "concat" op that has exactly one input: the producer op's output
+is renamed to concat's output variable, and the concat op plus its input
+variable are removed from the graph. */
+class DeleteConcatOpPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  const std::string name_scope_{"delete_concat_op_pass"};
+};
+// Runs ConcatPattern on `graph`, rewrites each match, passes count to AddStatis.
+void DeleteConcatOpPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+  GraphPatternDetector gpd;
+  patterns::ConcatPattern pattern(gpd.mutable_pattern(), name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle DeleteConcatOpPass fuse";
+#define GET_IR_NODE(node_) GET_IR_NODE_FROM_SUBGRAPH(node_, node_, pattern)
+    GET_IR_NODE(any_op);
+    GET_IR_NODE(concat);
+    GET_IR_NODE(any_op_out);
+    GET_IR_NODE(concat_out);
+#undef GET_IR_NODE
+    // Point any_op's output at concat_out so that concat is bypassed.
+    any_op->Op()->RenameOutput(any_op_out->Name(), concat_out->Name());
+    IR_NODE_LINK_TO(any_op, concat_out);
+    // The bypassed variable and the concat op are then removed from the graph.
+    std::unordered_set<const Node*> delete_nodes{any_op_out, concat};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_concat_op_pass, paddle::framework::ir::DeleteConcatOpPass);
+
+REGISTER_PASS_CAPABILITY(delete_concat_op_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "concat", 0));
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index b48b3606594..548d9436000 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -53,6 +53,7 @@ static const std::vector<std::string> support_subgraph_passes = {
 
 static const std::vector<std::string> xpu_support_subgraph_passes = {
     "delete_dropout_op_pass",
+    "delete_concat_op_pass",
     "identity_scale_op_clean_pass",
     "delete_op_device_pass",
"constant_folding_pass", diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 2e2a896f754..99a6f336ff3 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -519,6 +519,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() { XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { passes_.assign({ "delete_dropout_op_pass", + "delete_concat_op_pass", "identity_scale_op_clean_pass", "delete_op_device_pass", "constant_folding_pass", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_concat_op_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_concat_op_pass.py new file mode 100644 index 00000000000..0d35f5a5dc1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_concat_op_pass.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+
+import hypothesis.strategies as st
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestDeleteConcatOpPass(PassAutoScanTest):
+    """Auto-scan test for delete_concat_op_pass.
+
+    The sampled program is relu -> concat, where concat has one "X" input.
+    """
+
+    def sample_predictor_configs(self, program_config):
+        """Yield one inference config created with use_xpu=True.
+
+        NOTE(review): ["relu"] and (1e-5, 1e-5) are presumably the expected
+        op list and the comparison tolerances -- confirm in PassAutoScanTest.
+        """
+        xpu_config = self.create_inference_config(use_xpu=True)
+        yield xpu_config, ["relu"], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        """Build a relu -> concat program over a hypothesis-drawn shape."""
+        # The shape has 2 to 4 dims, each drawn from [1, 4].
+        dim = st.integers(min_value=1, max_value=4)
+        input_shape = draw(st.lists(dim, min_size=2, max_size=4))
+
+        relu = OpConfig(
+            "relu", inputs={"X": ["relu_x"]}, outputs={"Out": ["relu_out"]}
+        )
+        concat = OpConfig(
+            "concat",
+            inputs={"X": ["relu_out"]},
+            axis=0,
+            outputs={"Out": ["concat_out"]},
+        )
+        op_list = [relu, concat]
+
+        # The program output is concat's "Out" variable.
+        return ProgramConfig(
+            ops=op_list,
+            weights={},
+            inputs={"relu_x": TensorConfig(shape=input_shape)},
+            outputs=concat.outputs["Out"],
+        )
+
+    def test(self):
+        # delete_concat_op_pass is the only entry handed over in `passes`.
+        self.run_and_statis(
+            quant=False, max_examples=25, passes=["delete_concat_op_pass"]
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab