diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 021a494ef3c5452e95c18f8da76e7db1a0eac5e0..4b13152f554946a52137989456f8600f91957c40 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -100,6 +100,7 @@ pass_library(trt_delete_weight_dequant_linear_op_pass inference) pass_library(delete_op_device_pass inference) pass_library(delete_weight_dequant_linear_op_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference) +pass_library(delete_assign_op_pass inference) pass_library(delete_dropout_op_pass inference) pass_library(delete_concat_op_pass inference) pass_library(preln_residual_bias_fuse_pass inference) @@ -423,6 +424,10 @@ cc_test( test_delete_op_device_pass SRCS delete_op_device_pass_test.cc DEPS delete_op_device_pass) +cc_test( + test_delete_assign_op_pass_cc + SRCS delete_assign_op_pass_test.cc + DEPS delete_assign_op_pass) cc_test( test_delete_dropout_pass_cc SRCS delete_dropout_op_pass_test.cc diff --git a/paddle/fluid/framework/ir/delete_assign_op_pass.cc b/paddle/fluid/framework/ir/delete_assign_op_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..f032ef4216136826456ac012aacacb27552d1524 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_assign_op_pass.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct AssignWithSameInputOutputNamePattern : public PatternBase { + AssignWithSameInputOutputNamePattern(PDPattern* pattern, + const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(assign); +}; + +AssignWithSameInputOutputNamePattern::AssignWithSameInputOutputNamePattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + pattern->NewNode(assign_repr()) + ->assert_is_op("assign") + ->assert_more([](Node* node) { + auto in_name = node->Op()->Input("X")[0]; + auto out_name = node->Op()->Output("Out")[0]; + return in_name == out_name; + }); +} + +} // namespace patterns + +/* +Delete "assign" if its input and output is same. +*/ +class DeleteAssignOpPass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"delete_assign_op_pass"}; +}; + +void DeleteAssignOpPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + GraphPatternDetector gpd; + patterns::AssignWithSameInputOutputNamePattern pattern(gpd.mutable_pattern(), + name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle DeleteAssignOpPass fuse"; + GET_IR_NODE_FROM_SUBGRAPH(assign, assign, pattern); + + std::unordered_set delete_nodes{assign}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_assign_op_pass, paddle::framework::ir::DeleteAssignOpPass); + +REGISTER_PASS_CAPABILITY(delete_assign_op_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "assign", 0)); diff --git a/paddle/fluid/framework/ir/delete_assign_op_pass_test.cc b/paddle/fluid/framework/ir/delete_assign_op_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..92477747fe2beaa83e8977ac957d72bd2f9bd5ce --- /dev/null +++ b/paddle/fluid/framework/ir/delete_assign_op_pass_test.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(delete_assign_op_pass, basic) { + ProgramDesc program; + auto* x_var = program.MutableBlock(0)->Var("assign_x"); + auto* out_var = program.MutableBlock(0)->Var("assign_out"); + out_var->SetName(x_var->Name()); + OpDesc* assign_op = program.MutableBlock(0)->AppendOp(); + assign_op->SetType("assign"); + assign_op->SetInput("X", {x_var->Name()}); + assign_op->SetOutput("Out", {out_var->Name()}); + + std::unique_ptr graph(new Graph(program)); + auto pass = PassRegistry::Instance().Get("delete_assign_op_pass"); + graph.reset(pass->Apply(graph.release())); + int assign_num = GetNumOpNodes(graph, "assign"); + PADDLE_ENFORCE_EQ( + assign_num, + 0, + platform::errors::PreconditionNotMet( + "graph should have 0 assign after delete_assign_op_pass, " + "but actually has %d.", + assign_num)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(delete_assign_op_pass); diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 59b43d87447a59a0e44d569eb8dc98bfee5694d9..f11f52a0b1cdaaea3673b25828f8e4e7d2f3cf18 100755 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -52,6 +52,7 @@ static const std::vector support_subgraph_passes = { }; static const std::vector xpu_support_subgraph_passes = { + "delete_assign_op_pass", "delete_dropout_op_pass", "delete_concat_op_pass", "identity_op_clean_pass", diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc old mode 100755 new mode 100644 index 95285f9930181d2213e798b2fc9cdb25802ebb6f..41b90c63b1ead057700dce405e69210b69c92be7 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -508,6 +508,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() { XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { passes_.assign({ + "delete_assign_op_pass", "delete_dropout_op_pass", "delete_concat_op_pass", "identity_op_clean_pass",