未验证 提交 70ebef81 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] add delete_concat_op_pass (#52304)

上级 8ef97088
......@@ -101,6 +101,7 @@ pass_library(delete_op_device_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_concat_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
// Forward declarations of types defined elsewhere in the project. Neither
// name is referenced by the code visible in this file.
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern describing the chain
//   any_op -> any_op_out -> concat -> concat_out
// where `concat` is a "concat" operator with exactly one entry in its "X"
// input slot. The constraints on each node are set up in the constructor.
struct ConcatPattern : public PatternBase {
// Registers the four pattern nodes below in `pattern`; `name_scope` is
// forwarded to PatternBase.
ConcatPattern(PDPattern* pattern, const std::string& name_scope);
// Operator nodes: the producer of concat's input and the concat op itself.
PATTERN_DECL_NODE(any_op);
PATTERN_DECL_NODE(concat);
// Variable nodes: the producer's output (concat's "X") and concat's "Out".
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(concat_out);
};
// Builds the pattern `any_op -> any_op_out -> concat -> concat_out`.
//
// Constraints placed on the nodes:
//   * any_op      : any operator node.
//   * any_op_out  : variable used as the "X" input of a "concat" op, asserted
//                   to have one input node and one output node.
//   * concat      : "concat" operator asserted to have one input node and
//                   exactly one name listed under "X".
//   * concat_out  : variable that is the "Out" output of a "concat" op.
ConcatPattern::ConcatPattern(PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // Teller for the concat node: true when its "X" slot holds a single name.
  const auto has_single_x_entry = [](Node* candidate) {
    return candidate->Op()->Input("X").size() == 1;
  };

  auto* producer_op = pattern->NewNode(any_op_repr())->assert_is_op();

  auto* producer_out = pattern->NewNode(any_op_out_repr())
                           ->assert_is_op_input("concat", "X")
                           ->assert_has_n_inputs(1)
                           ->assert_has_n_outputs(1);

  auto* concat_op = pattern->NewNode(concat_repr())
                        ->assert_is_op("concat")
                        ->assert_has_n_inputs(1)
                        ->assert_more(has_single_x_entry);

  auto* concat_result = pattern->NewNode(concat_out_repr())
                            ->assert_is_op_output("concat", "Out");

  // Wire the chain: producer_op -> producer_out -> concat_op -> concat_result.
  producer_op->LinksTo({producer_out});
  concat_op->LinksFrom({producer_out}).LinksTo({concat_result});
}
} // namespace patterns
/*
Deletes a "concat" op that has only one input: the op producing that input is
rewired to write directly to the concat's output variable, and the concat op
together with its input variable is removed from the graph.
*/
class DeleteConcatOpPass : public FusePassBase {
protected:
// Applies the rewrite to every matching subgraph in `graph`.
void ApplyImpl(ir::Graph* graph) const override;
private:
// Name scope passed to Init() and to the pattern constructor.
const std::string name_scope_{"delete_concat_op_pass"};
};
// Detects every `any_op -> any_op_out -> concat -> concat_out` subgraph in
// `graph` (see patterns::ConcatPattern) and removes the concat:
//   1. the producer's output is renamed from any_op_out to concat_out,
//   2. the producer is linked to the concat_out node,
//   3. any_op_out and concat are removed from the graph.
// The number of rewritten subgraphs is reported through AddStatis().
// Enforces that `graph` is not null.
void DeleteConcatOpPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector detector;
  patterns::ConcatPattern concat_pattern(detector.mutable_pattern(),
                                         name_scope_);

  int removed_concat_count = 0;
  auto on_match = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* graph) {
    VLOG(4) << "handle DeleteConcatOpPass fuse";

    // Fetch the concrete graph nodes bound to each pattern node.
    GET_IR_NODE_FROM_SUBGRAPH(any_op, any_op, concat_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(concat, concat, concat_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(any_op_out, any_op_out, concat_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_pattern);

    // Let the producer write straight into concat's output variable.
    any_op->Op()->RenameOutput(any_op_out->Name(), concat_out->Name());
    IR_NODE_LINK_TO(any_op, concat_out);

    // The intermediate variable and the concat op are no longer needed.
    std::unordered_set<const Node*> obsolete_nodes{any_op_out, concat};
    GraphSafeRemoveNodes(graph, obsolete_nodes);

    ++removed_concat_count;
  };

  detector(graph, on_match);
  AddStatis(removed_concat_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register DeleteConcatOpPass under the name "delete_concat_op_pass".
REGISTER_PASS(delete_concat_op_pass, paddle::framework::ir::DeleteConcatOpPass);
// Declare the pass capability: an op-version combination requiring the
// "concat" op to compare EQ to version 0.
REGISTER_PASS_CAPABILITY(delete_concat_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"concat", 0));
......@@ -53,6 +53,7 @@ static const std::vector<std::string> support_subgraph_passes = {
static const std::vector<std::string> xpu_support_subgraph_passes = {
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_scale_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
......
......@@ -519,6 +519,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_scale_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestDeleteConcatOpPass(PassAutoScanTest):
    """Auto-scan test for ``delete_concat_op_pass`` on a ``relu -> concat`` program."""

    def sample_predictor_configs(self, program_config):
        """Yield one XPU inference config with its op list and tolerance pair."""
        xpu_config = self.create_inference_config(use_xpu=True)
        # NOTE(review): ["relu"] is presumably the op list expected after the
        # pass has run, and the tuple the (atol, rtol) pair -- confirm against
        # PassAutoScanTest.
        yield xpu_config, ["relu"], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        """Draw a random input shape and build a ``relu -> concat`` program."""
        # Between 2 and 4 dimensions, each of size 1 to 4.
        dim_size = st.integers(min_value=1, max_value=4)
        input_shape = draw(st.lists(dim_size, min_size=2, max_size=4))

        relu = OpConfig(
            "relu",
            inputs={"X": ["relu_x"]},
            outputs={"Out": ["relu_out"]},
        )
        # concat with a single "X" entry: the case the pass is meant to delete.
        concat = OpConfig(
            "concat",
            inputs={"X": ["relu_out"]},
            axis=0,
            outputs={"Out": ["concat_out"]},
        )

        return ProgramConfig(
            ops=[relu, concat],
            weights={},
            inputs={"relu_x": TensorConfig(shape=input_shape)},
            outputs=concat.outputs["Out"],
        )

    def test(self):
        """Run the auto-scan harness with only ``delete_concat_op_pass`` enabled."""
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["delete_concat_op_pass"],
        )
if __name__ == "__main__":
    # Run the test case when this file is executed as a script.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册