Unverified · Commit 2363e623 · Authored by Yuanle Liu · Committed by GitHub

[Inference] rewrite identity_op_clean_pass (#55240)

* rewrite identity_op_clean_pass

* fix

* adjust identity_op_clean_pass order in gpu passes

* fix ut
Parent 006bd959
paddle/fluid/framework/ir/CMakeLists.txt:
@@ -434,6 +434,10 @@ cc_test(
   test_delete_assign_op_pass_cc
   SRCS delete_assign_op_pass_test.cc
   DEPS delete_assign_op_pass)
+cc_test(
+  test_identity_op_clean_pass_cc
+  SRCS identity_op_clean_pass_test.cc
+  DEPS identity_op_clean_pass)
 cc_test(
   test_delete_dropout_pass_cc
   SRCS delete_dropout_op_pass_test.cc
......
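(With this target in place, the new unit test can be run on its own from a build directory, e.g. ctest -R test_identity_op_clean_pass_cc.)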
paddle/fluid/framework/ir/identity_op_clean_pass.cc:
@@ -21,24 +21,33 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-class Graph;
-
-void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
-  FusePassBase::Init("identity_scale_op_clean", graph);
-
-  // pre_op -> useless_op_in -> useless_op -> useless_op_out
-  // ->
-  // pre_op -> useless_op_out
-  GraphPatternDetector detector;
-  auto useless_op_in =
-      detector.mutable_pattern()
-          ->NewNode("useless_op_in")
-          ->assert_has_n_outputs(1)
-          ->assert_var_not_persistable()
-          ->assert_more([](Node* x) {
-            for (auto* op : x->inputs) {
-              auto op_type = op->Op()->Type();
-              if (op_type == "conditional_block" || op_type == "while") {
-                return false;
-              }
-            }
+namespace patterns {
+
+// pre_op -> useless_op_in -> useless_op -> useless_op_out
+// ->
+// pre_op -> useless_op_out
+struct FindUselessOpPattern : public PatternBase {
+  FindUselessOpPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(useless_op_in);
+  PATTERN_DECL_NODE(useless_op);
+  PATTERN_DECL_NODE(useless_op_out);
+};
+
+FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern,
+                                           const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* useless_op_in = pattern->NewNode(useless_op_in_repr())
+                            ->assert_is_var()
+                            ->assert_var_not_persistable()
+                            ->assert_has_n_outputs(1)
+                            ->assert_more([](Node* x) {
+                              for (auto* op : x->inputs) {
+                                CHECK_EQ(op->IsOp(), true);
+                                const auto& op_type = op->Op()->Type();
+                                if (op_type == "conditional_block" ||
+                                    op_type == "while" || op_type == "feed") {
+                                  return false;
+                                }
+                              }
@@ -46,79 +55,71 @@ void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
-            return true;
-          });
+                              return true;
+                            });
 
   // This useless_op must have only one input and one output!
-  auto useless_op =
-      detector.mutable_pattern()
-          ->NewNode("useless_op")
-          ->assert_has_n_inputs(1)
-          ->assert_has_n_outputs(1)
-          ->assert_more([](Node* x) {
-            if (!x->IsOp()) {
-              return false;
-            }
-            if (x->Op()->Type() == "scale") {
-              auto scale = x->Op()->GetAttrIfExists<float>("scale");
-              auto bias = x->Op()->GetAttrIfExists<float>("bias");
-              if (bias == 0 && scale == 1) {
-                return true;
-              }
-            }
-            if (x->Op()->Type() == "cast") {
-              auto in_dtype = x->Op()->GetAttrIfExists<int>("in_dtype");
-              auto out_dtype = x->Op()->GetAttrIfExists<int>("out_dtype");
-              if (in_dtype == out_dtype) {
-                return true;
-              }
-            }
-            if (x->Op()->Type() == "c_identity") {
-              return true;
-            }
-            // you can add more cases here.
-            return false;
-          });
-  auto useless_op_out = detector.mutable_pattern()->NewNode("useless_op_out");
+  auto* useless_op =
+      pattern->NewNode(useless_op_repr())
+          ->assert_is_op()
+          ->assert_has_n_inputs(1)
+          ->assert_has_n_outputs(1)
+          ->assert_more([](Node* x) {
+            const auto& op_type = x->Op()->Type();
+            if (op_type == "scale") {
+              auto scale = x->Op()->GetAttrIfExists<float>("scale");
+              auto bias = x->Op()->GetAttrIfExists<float>("bias");
+              return bias == 0.f && scale == 1.f;
+            } else if (op_type == "cast") {
+              auto in_dtype = x->Op()->GetAttrIfExists<int>("in_dtype");
+              auto out_dtype = x->Op()->GetAttrIfExists<int>("out_dtype");
+              return in_dtype == out_dtype;
+            } else if (op_type == "c_identity") {
+              return true;
+            } else if (op_type == "assign") {
+              const auto& in_name = x->Op()->Input("X")[0];
+              const auto& out_name = x->Op()->Output("Out")[0];
+              return in_name == out_name;
+            } else if (op_type == "concat") {
+              return x->Op()->Input("X").size() == 1;
+            }
+            // you can add more cases here.
+            return false;
+          });
+
+  auto* useless_op_out =
+      pattern->NewNode(useless_op_out_repr())->assert_is_var();
 
   useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
+}
+
+}  // namespace patterns
+
+void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
+  Init(name_scope_, graph);
+
+  GraphPatternDetector gpd;
+  patterns::FindUselessOpPattern pattern(gpd.mutable_pattern(), name_scope_);
 
-  int found_subgraph_count = 0;
+  int found_count = 0;
   GraphPatternDetector::handle_t handler =
       [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
-        Node* useless_op_var = subgraph.at(useless_op);
-        Node* useless_op_in_var = subgraph.at(useless_op_in);
-        Node* useless_op_out_var = subgraph.at(useless_op_out);
-        const std::string useless_op_in_name = useless_op_in_var->Name();
-        const std::string useless_op_out_name = useless_op_out_var->Name();
-        // Remove links in graph
-        GraphSafeRemoveNodes(graph, {useless_op_in_var, useless_op_var});
-        // Modify pre_op_desc
-        // Link pre_op directly to scale_out
-        for (auto& node : graph->Nodes()) {
-          if (node->IsOp()) {
-            auto* op_desc = node->Op();
-            auto out_vars_map = op_desc->Outputs();
-            for (auto out_var_map : out_vars_map) {
-              auto names = out_var_map.second;
-              bool reset = false;
-              for (size_t i = 0; i < names.size(); i++) {
-                if (names[i] == useless_op_in_name) {
-                  reset = true;
-                  names[i] = useless_op_out_name;
-                  break;
-                }
-              }
-              if (reset) {
-                op_desc->SetOutput(out_var_map.first, names);
-                op_desc->Flush();
-                IR_NODE_LINK_TO(node, useless_op_out_var);
-                break;
-              }
-            }
-          }
-        }
-        found_subgraph_count++;
+        GET_IR_NODE_FROM_SUBGRAPH(useless_op_in, useless_op_in, pattern);
+        GET_IR_NODE_FROM_SUBGRAPH(useless_op, useless_op, pattern);
+        GET_IR_NODE_FROM_SUBGRAPH(useless_op_out, useless_op_out, pattern);
+        CHECK_EQ(useless_op_in->IsVar(), true);
+        CHECK_EQ(useless_op_out->IsVar(), true);
+        CHECK_EQ(useless_op->IsOp(), true);
+
+        for (auto* prev_op : useless_op_in->inputs) {
+          CHECK_EQ(prev_op->IsOp(), true);
+          prev_op->Op()->RenameOutput(useless_op_in->Var()->Name(),
+                                      useless_op_out->Var()->Name());
+          IR_NODE_LINK_TO(prev_op, useless_op_out);
+        }
+
+        GraphSafeRemoveNodes(graph, {useless_op_in, useless_op});
+        found_count++;
       };
 
-  detector(graph, handler);
-  AddStatis(found_subgraph_count);
+  gpd(graph, handler);
+  AddStatis(found_count);
 }
 
 }  // namespace ir
......
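Note on the new structure: the pattern matches a var -> no-op -> var chain, and the handler renames the producer's output so it writes directly to the no-op's output variable, relinks the producer, and deletes the no-op together with its input variable. This replaces the old whole-graph scan over every OpDesc. Below is a minimal, dependency-free C++ sketch of that rewiring; ToyOp and IsIdentity are invented stand-ins for illustration, not Paddle's real ir::Node/OpDesc API:

#include <algorithm>
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for Paddle's ir::Node/OpDesc; illustration only.
struct ToyOp {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Stands in for the assert_more checks above (scale(1,0), cast with equal
// dtypes, c_identity, self-assign, single-input concat); one case suffices.
bool IsIdentity(const ToyOp& op) { return op.type == "c_identity"; }

int main() {
  // relu -> t -> c_identity -> out   becomes   relu -> out
  std::vector<ToyOp> graph = {
      {"relu", {"x"}, {"t"}},
      {"c_identity", {"t"}, {"out"}},
  };

  for (auto it = graph.begin(); it != graph.end();) {
    if (IsIdentity(*it) && it->inputs.size() == 1 && it->outputs.size() == 1) {
      const std::string in = it->inputs[0];
      const std::string out = it->outputs[0];
      // Like OpDesc::RenameOutput in the handler: every producer of `in`
      // now writes `out` directly.
      for (auto& op : graph) {
        std::replace(op.outputs.begin(), op.outputs.end(), in, out);
      }
      it = graph.erase(it);  // like GraphSafeRemoveNodes
    } else {
      ++it;
    }
  }

  assert(graph.size() == 1 && graph[0].outputs[0] == "out");
  std::cout << graph[0].type << " now writes directly to 'out'\n";
  return 0;
}

The real handler additionally refuses persistable variables and variables produced by conditional_block, while, or feed ops, as the asserts in the pattern show.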
paddle/fluid/framework/ir/identity_op_clean_pass.h:
@@ -27,7 +27,7 @@ class IdentityOpCleanPass : public FusePassBase {
   void ApplyImpl(ir::Graph* graph) const override;
 
  private:
-  virtual ~IdentityOpCleanPass() = default;
+  const std::string name_scope_{"identity_op_clean_pass"};
 };
 
 }  // namespace ir
......
paddle/fluid/framework/ir/identity_op_clean_pass_test.cc (new file):
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

TEST(identity_op_clean_pass, assign) {
  ProgramDesc program;
  auto* x_var = program.MutableBlock(0)->Var("assign_x");
  auto* out_var = program.MutableBlock(0)->Var("assign_out");
  out_var->SetName(x_var->Name());
  OpDesc* assign_op = program.MutableBlock(0)->AppendOp();
  assign_op->SetType("assign");
  assign_op->SetInput("X", {x_var->Name()});
  assign_op->SetOutput("Out", {out_var->Name()});

  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));
  int assign_num = GetNumOpNodes(graph, "assign");
  PADDLE_ENFORCE_EQ(
      assign_num,
      0,
      platform::errors::PreconditionNotMet(
          "graph should have 0 assign after identity_op_clean_pass, "
          "but actually has %d.",
          assign_num));
}

TEST(identity_op_clean_pass, scale) {
  ProgramDesc program;
  auto* x_var = program.MutableBlock(0)->Var("scale_x");
  auto* out_var = program.MutableBlock(0)->Var("scale_out");
  OpDesc* scale_op = program.MutableBlock(0)->AppendOp();
  scale_op->SetType("scale");
  scale_op->SetInput("X", {x_var->Name()});
  scale_op->SetOutput("Out", {out_var->Name()});
  scale_op->SetAttr("scale", 1.f);
  scale_op->SetAttr("bias", 0.f);

  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));
  int scale_num = GetNumOpNodes(graph, "scale");
  PADDLE_ENFORCE_EQ(
      scale_num,
      0,
      platform::errors::PreconditionNotMet(
          "graph should have 0 scale op after identity_op_clean_pass, "
          "but actually has %d.",
          scale_num));
}

TEST(identity_op_clean_pass, cast) {
  ProgramDesc program;
  auto* x_var = program.MutableBlock(0)->Var("cast_x");
  auto* out_var = program.MutableBlock(0)->Var("cast_out");
  OpDesc* cast_op = program.MutableBlock(0)->AppendOp();
  cast_op->SetType("cast");
  cast_op->SetInput("X", {x_var->Name()});
  cast_op->SetOutput("Out", {out_var->Name()});
  cast_op->SetAttr("in_dtype", 5);
  cast_op->SetAttr("out_dtype", 5);

  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));
  int cast_num = GetNumOpNodes(graph, "cast");
  PADDLE_ENFORCE_EQ(
      cast_num,
      0,
      platform::errors::PreconditionNotMet(
          "graph should have 0 cast after identity_op_clean_pass, "
          "but actually has %d.",
          cast_num));
}

TEST(identity_op_clean_pass, concat) {
  ProgramDesc program;
  auto* x_var = program.MutableBlock(0)->Var("concat_x");
  auto* out_var = program.MutableBlock(0)->Var("concat_out");
  OpDesc* concat_op = program.MutableBlock(0)->AppendOp();
  concat_op->SetType("concat");
  concat_op->SetInput("X", {x_var->Name()});
  concat_op->SetOutput("Out", {out_var->Name()});

  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));
  int concat_num = GetNumOpNodes(graph, "concat");
  PADDLE_ENFORCE_EQ(
      concat_num,
      0,
      platform::errors::PreconditionNotMet(
          "graph should have 0 concat after identity_op_clean_pass, "
          "but actually has %d.",
          concat_num));
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(identity_op_clean_pass);
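The pass also erases c_identity ops, which the four tests above do not exercise. A fifth test could follow the same template, placed alongside the others inside namespace ir; this is a hypothetical sketch, not part of this commit (at graph level the pass inspects only the op type string, so no collective attributes are set):

TEST(identity_op_clean_pass, c_identity) {
  // Hypothetical test mirroring the ones above; not part of this commit.
  ProgramDesc program;
  auto* x_var = program.MutableBlock(0)->Var("c_identity_x");
  auto* out_var = program.MutableBlock(0)->Var("c_identity_out");
  OpDesc* c_identity_op = program.MutableBlock(0)->AppendOp();
  c_identity_op->SetType("c_identity");
  c_identity_op->SetInput("X", {x_var->Name()});
  c_identity_op->SetOutput("Out", {out_var->Name()});

  std::unique_ptr<Graph> graph(new Graph(program));
  auto pass = PassRegistry::Instance().Get("identity_op_clean_pass");
  graph.reset(pass->Apply(graph.release()));
  int c_identity_num = GetNumOpNodes(graph, "c_identity");
  PADDLE_ENFORCE_EQ(
      c_identity_num,
      0,
      platform::errors::PreconditionNotMet(
          "graph should have 0 c_identity after identity_op_clean_pass, "
          "but actually has %d.",
          c_identity_num));
}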
paddle/fluid/inference/api/paddle_pass_builder.cc:
@@ -221,7 +221,6 @@ const std::vector<std::string> kCINNCompilerPasses{
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
         "map_op_to_another_pass",                 //
-        "identity_op_clean_pass",                 //
         "is_test_pass",                           //
         "simplify_with_basic_ops_pass",           //
         "delete_quant_dequant_linear_op_pass",    //
@@ -262,6 +261,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",         //
 #endif                                            //
         "transpose_flatten_concat_fuse_pass",     //
+        "identity_op_clean_pass",                 //
         "conv2d_fusion_layout_transfer_pass",     //
         "transfer_layout_elim_pass",
         "auto_mixed_precision_pass",              //
......
Python inference pass auto-scan test:
@@ -20,6 +20,54 @@ from program_config import OpConfig, ProgramConfig, TensorConfig
 
 class TestIdentityScaleCleanPass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_gpu=True)
+        yield config, ['relu', 'relu', 'scale'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        bias_after_scale = draw(st.booleans())
+        n = draw(st.integers(min_value=1, max_value=4))
+        c = draw(st.integers(min_value=1, max_value=20))
+        h = draw(st.integers(min_value=1, max_value=20))
+        w = draw(st.integers(min_value=1, max_value=20))
+        relu_op1 = OpConfig(
+            "relu", inputs={"X": ["relu_x"]}, outputs={"Out": ["relu_op1_out"]}
+        )
+        scale_op1 = OpConfig(
+            "scale",
+            inputs={"X": ["relu_op1_out"]},
+            outputs={"Out": ["scale_op1_out"]},
+            bias=0.0,
+            scale=1.0,
+            bias_after_scale=True,
+        )
+        scale_op2 = OpConfig(
+            "scale",
+            inputs={"X": ["scale_op1_out"]},
+            outputs={"Out": ["scale_op2_out"]},
+            bias=0.0,
+            scale=1.0,
+            bias_after_scale=True,
+        )
+        relu_op2 = OpConfig(
+            "relu",
+            inputs={"X": ["relu_op1_out"]},
+            outputs={"Out": ["relu_op2_out"]},
+        )
+        program_config = ProgramConfig(
+            ops=[relu_op1, relu_op2, scale_op1, scale_op2],
+            weights={},
+            inputs={"relu_x": TensorConfig(shape=[n, c, h, w])},
+            outputs=["scale_op2_out", "relu_op2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
+
+
+class TestIdentityScaleCleanPass_V1(PassAutoScanTest):
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_gpu=True)
         yield config, ['relu'], (1e-5, 1e-5)
@@ -31,22 +79,75 @@ class TestIdentityScaleCleanPass(PassAutoScanTest):
         h = draw(st.integers(min_value=1, max_value=20))
         w = draw(st.integers(min_value=1, max_value=20))
 
-        relu_op = OpConfig(
-            "relu", inputs={"X": ["relu_x"]}, outputs={"Out": ["relu_out"]}
+        relu_op1 = OpConfig(
+            "relu", inputs={"X": ["relu_x"]}, outputs={"Out": ["relu_op1_out"]}
         )
-        scale_op = OpConfig(
+        scale_op1 = OpConfig(
             "scale",
-            inputs={"X": ["relu_out"]},
-            outputs={"Out": ["scale_out"]},
+            inputs={"X": ["relu_op1_out"]},
+            outputs={"Out": ["scale_op1_out"]},
+            bias=0.0,
+            scale=1.0,
+            bias_after_scale=True,
+        )
+        scale_op2 = OpConfig(
+            "scale",
+            inputs={"X": ["scale_op1_out"]},
+            outputs={"Out": ["scale_op2_out"]},
             bias=0.0,
             scale=1.0,
             bias_after_scale=True,
         )
         program_config = ProgramConfig(
-            ops=[relu_op, scale_op],
+            ops=[relu_op1, scale_op1, scale_op2],
             weights={},
             inputs={"relu_x": TensorConfig(shape=[n, c, h, w])},
-            outputs=["scale_out"],
+            outputs=["scale_op2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
+
+
+class TestIdentityScaleCleanPass_V2(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_gpu=True)
+        yield config, ['scale', 'relu'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        bias_after_scale = draw(st.booleans())
+        n = draw(st.integers(min_value=1, max_value=4))
+        c = draw(st.integers(min_value=1, max_value=20))
+        h = draw(st.integers(min_value=1, max_value=20))
+        w = draw(st.integers(min_value=1, max_value=20))
+        scale_op1 = OpConfig(
+            "scale",
+            inputs={"X": ["scale_op1_in"]},
+            outputs={"Out": ["scale_op1_out"]},
+            bias=0.0,
+            scale=1.0,
+            bias_after_scale=True,
+        )
+        scale_op2 = OpConfig(
+            "scale",
+            inputs={"X": ["scale_op1_out"]},
+            outputs={"Out": ["scale_op2_out"]},
+            bias=0.0,
+            scale=1.0,
+            bias_after_scale=True,
+        )
+        relu_op1 = OpConfig(
+            "relu",
+            inputs={"X": ["scale_op2_out"]},
+            outputs={"Out": ["relu_op1_out"]},
+        )
+        program_config = ProgramConfig(
+            ops=[scale_op1, scale_op2, relu_op1],
+            weights={},
+            inputs={"scale_op1_in": TensorConfig(shape=[n, c, h, w])},
+            outputs=["relu_op1_out"],
         )
         return program_config
......