未验证 提交 813266a2 编写于 作者: Z zhupengyang 提交者: GitHub

delete_assign_op_pass (#54887)

上级 b58869fa
......@@ -100,6 +100,7 @@ pass_library(trt_delete_weight_dequant_linear_op_pass inference)
pass_library(delete_op_device_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_assign_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_concat_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
......@@ -423,6 +424,10 @@ cc_test(
test_delete_op_device_pass
SRCS delete_op_device_pass_test.cc
DEPS delete_op_device_pass)
cc_test(
test_delete_assign_op_pass_cc
SRCS delete_assign_op_pass_test.cc
DEPS delete_assign_op_pass)
cc_test(
test_delete_dropout_pass_cc
SRCS delete_dropout_op_pass_test.cc
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct AssignWithSameInputOutputNamePattern : public PatternBase {
AssignWithSameInputOutputNamePattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(assign);
};
AssignWithSameInputOutputNamePattern::AssignWithSameInputOutputNamePattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
pattern->NewNode(assign_repr())
->assert_is_op("assign")
->assert_more([](Node* node) {
auto in_name = node->Op()->Input("X")[0];
auto out_name = node->Op()->Output("Out")[0];
return in_name == out_name;
});
}
} // namespace patterns
/*
Delete "assign" if its input and output is same.
*/
class DeleteAssignOpPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
const std::string name_scope_{"delete_assign_op_pass"};
};
void DeleteAssignOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::AssignWithSameInputOutputNamePattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DeleteAssignOpPass fuse";
GET_IR_NODE_FROM_SUBGRAPH(assign, assign, pattern);
std::unordered_set<const Node*> delete_nodes{assign};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_assign_op_pass, paddle::framework::ir::DeleteAssignOpPass);
REGISTER_PASS_CAPABILITY(delete_assign_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"assign", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(delete_assign_op_pass, basic) {
ProgramDesc program;
auto* x_var = program.MutableBlock(0)->Var("assign_x");
auto* out_var = program.MutableBlock(0)->Var("assign_out");
out_var->SetName(x_var->Name());
OpDesc* assign_op = program.MutableBlock(0)->AppendOp();
assign_op->SetType("assign");
assign_op->SetInput("X", {x_var->Name()});
assign_op->SetOutput("Out", {out_var->Name()});
std::unique_ptr<Graph> graph(new Graph(program));
auto pass = PassRegistry::Instance().Get("delete_assign_op_pass");
graph.reset(pass->Apply(graph.release()));
int assign_num = GetNumOpNodes(graph, "assign");
PADDLE_ENFORCE_EQ(
assign_num,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 assign after delete_assign_op_pass, "
"but actually has %d.",
assign_num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_assign_op_pass);
......@@ -52,6 +52,7 @@ static const std::vector<std::string> support_subgraph_passes = {
};
static const std::vector<std::string> xpu_support_subgraph_passes = {
"delete_assign_op_pass",
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_op_clean_pass",
......
......@@ -508,6 +508,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_assign_op_pass",
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_op_clean_pass",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册