未验证 提交 24efc13b 编写于 作者: 周周周 提交者: GitHub

[Paddle Inference] remove redunant op (#54442)

* remove redunant
上级 a0d59f9d
......@@ -89,7 +89,7 @@ pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(inplace_op_var_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(identity_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
......@@ -102,7 +102,6 @@ pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(delete_concat_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_c_identity_op_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
void DeleteCIdentityOpPattern::operator()() {
auto any_op_out = pattern->NewNode(any_op_out_repr())
->assert_is_op_input("c_identity", "X")
->AsInput();
auto c_identity_op =
pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
auto c_identity_op_out = pattern->NewNode(c_identity_op_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
c_identity_op->LinksFrom({any_op_out});
c_identity_op_out->LinksFrom({c_identity_op});
any_op2->LinksFrom({c_identity_op_out});
}
} // namespace patterns
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(c_identity_op); \
GET_IR_NODE(c_identity_op_out); \
GET_IR_NODE(any_op2);
void DeleteCIdentityOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_c_identity_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
patterns::DeleteCIdentityOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string c_identity_op_out_name = c_identity_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
c_identity_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
if (arg_name.size() == 0) {
LOG(INFO) << "Delete c_identity op pass: can not find the input "
<< c_identity_op_out_name;
return;
}
// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
c_identity_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != c_identity_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
}
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {c_identity_op, c_identity_op_out});
};
gpd(graph, handler);
}
DeleteCIdentityOpPass::DeleteCIdentityOpPass() {
AddOpCompat(OpCompat("c_identity"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_c_identity_op_pass,
paddle::framework::ir::DeleteCIdentityOpPass);
REGISTER_PASS_CAPABILITY(delete_c_identity_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"c_identity", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct DeleteCIdentityOpPattern : public PatternBase {
DeleteCIdentityOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_c_identity_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_op_out);
PATTERN_DECL_NODE(any_op2);
};
} // namespace patterns
class Graph;
class DeleteCIdentityOpPass : public FusePassBase {
public:
DeleteCIdentityOpPass();
virtual ~DeleteCIdentityOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/identity_scale_op_clean_pass.h"
#include "paddle/fluid/framework/ir/identity_op_clean_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
......@@ -23,17 +23,16 @@ namespace ir {
class Graph;
void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("identity_scale_op_clean", graph);
// pre_op -> scale_in -> scale_op -> scale_out
// pre_op -> useless_op_in -> useless_op -> useless_op_out
// ->
// pre_op -> scale_out
// pre_op -> useless_op_out
GraphPatternDetector detector;
auto scale_in =
auto useless_op_in =
detector.mutable_pattern()
->NewNode("scale_in")
->assert_is_op_input("scale")
->NewNode("useless_op_in")
->assert_has_n_outputs(1)
->assert_var_not_persistable()
->assert_more([](Node* x) {
......@@ -45,27 +44,51 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
}
return true;
});
auto scale_op = detector.mutable_pattern()
->NewNode("scale_fuse")
->assert_is_op("scale")
->assert_op_attr<float>("scale", 1.)
->assert_op_attr<float>("bias", 0.);
auto scale_out = detector.mutable_pattern()
->NewNode("scale_out")
->assert_is_op_output("scale");
scale_op->LinksFrom({scale_in}).LinksTo({scale_out});
// This useless_op must have only one input and one output!
auto useless_op =
detector.mutable_pattern()
->NewNode("useless_op")
->assert_has_n_inputs(1)
->assert_has_n_outputs(1)
->assert_more([](Node* x) {
if (!x->IsOp()) {
return false;
}
if (x->Op()->Type() == "scale") {
auto scale = x->Op()->GetAttrIfExists<float>("scale");
auto bias = x->Op()->GetAttrIfExists<float>("bias");
if (std::abs(bias) <= 1e-6 && std::abs(scale - 1) <= 1e-6) {
return true;
}
}
if (x->Op()->Type() == "cast") {
auto in_dtype = x->Op()->GetAttrIfExists<int>("in_dtype");
auto out_dtype = x->Op()->GetAttrIfExists<int>("out_dtype");
if (in_dtype == out_dtype) {
return true;
}
}
if (x->Op()->Type() == "c_identity") {
return true;
}
// you can add more cases here.
return false;
});
auto useless_op_out = detector.mutable_pattern()->NewNode("useless_op_out");
useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
int found_subgraph_count = 0;
GraphPatternDetector::handle_t handler =
[&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* scale_op_var = subgraph.at(scale_op);
Node* scale_in_var = subgraph.at(scale_in);
Node* scale_out_var = subgraph.at(scale_out);
const std::string scale_in_name = scale_in_var->Name();
const std::string scale_out_name = scale_out_var->Name();
Node* useless_op_var = subgraph.at(useless_op);
Node* useless_op_in_var = subgraph.at(useless_op_in);
Node* useless_op_out_var = subgraph.at(useless_op_out);
const std::string useless_op_in_name = useless_op_in_var->Name();
const std::string useless_op_out_name = useless_op_out_var->Name();
// Remove links in graph
GraphSafeRemoveNodes(graph, {scale_in_var, scale_op_var});
GraphSafeRemoveNodes(graph, {useless_op_in_var, useless_op_var});
// Modify pre_op_desc
// Link pre_op directly to scale_out
for (auto& node : graph->Nodes()) {
......@@ -76,16 +99,16 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
auto names = out_var_map.second;
bool reset = false;
for (size_t i = 0; i < names.size(); i++) {
if (names[i] == scale_in_name) {
if (names[i] == useless_op_in_name) {
reset = true;
names[i] = scale_out_name;
names[i] = useless_op_out_name;
break;
}
}
if (reset) {
op_desc->SetOutput(out_var_map.first, names);
op_desc->Flush();
IR_NODE_LINK_TO(node, scale_out_var);
IR_NODE_LINK_TO(node, useless_op_out_var);
break;
}
}
......@@ -102,9 +125,10 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework
} // namespace paddle
REGISTER_PASS(identity_scale_op_clean_pass,
paddle::framework::ir::IdentityScaleOpCleanPass);
REGISTER_PASS_CAPABILITY(identity_scale_op_clean_pass)
REGISTER_PASS(identity_op_clean_pass,
paddle::framework::ir::IdentityOpCleanPass);
REGISTER_PASS_CAPABILITY(identity_op_clean_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"scale", 0));
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("scale", 0)
.LE("c_identity", 1));
......@@ -22,12 +22,12 @@ namespace ir {
class Graph;
class IdentityScaleOpCleanPass : public FusePassBase {
class IdentityOpCleanPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
virtual ~IdentityScaleOpCleanPass() = default;
virtual ~IdentityOpCleanPass() = default;
};
} // namespace ir
......
......@@ -54,7 +54,7 @@ static const std::vector<std::string> support_subgraph_passes = {
static const std::vector<std::string> xpu_support_subgraph_passes = {
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_scale_op_clean_pass",
"identity_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
"delete_elementwise_mul_op_pass",
......
......@@ -94,12 +94,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
"delete_quant_dequant_filter_op_pass", //
"trt_delete_weight_dequant_linear_op_pass", //
"delete_quant_dequant_linear_op_pass", //
"identity_scale_op_clean_pass", //
"identity_op_clean_pass", //
"add_support_int8_pass", //
"simplify_with_basic_ops_pass", //
"trt_embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", //
"delete_c_identity_op_pass", //
"trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", //
"multihead_matmul_roformer_fuse_pass", //
......@@ -175,7 +174,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
// running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{
"map_op_to_another_pass",
"identity_scale_op_clean_pass",
"identity_op_clean_pass",
"simplify_with_basic_ops_pass",
"silu_fuse_pass",
"delete_quant_dequant_linear_op_pass",
......@@ -222,7 +221,7 @@ const std::vector<std::string> kCINNCompilerPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"map_op_to_another_pass", //
"identity_scale_op_clean_pass", //
"identity_op_clean_pass", //
"is_test_pass", //
"simplify_with_basic_ops_pass", //
"delete_quant_dequant_linear_op_pass", //
......@@ -511,7 +510,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_scale_op_clean_pass",
"identity_op_clean_pass",
"delete_repeated_ops_pass",
"delete_op_device_pass",
"constant_folding_pass",
......
......@@ -56,7 +56,7 @@ class TestDeleteCIdentityPass(PassAutoScanTest):
self.run_and_statis(
max_examples=2,
min_success_num=2,
passes=["delete_c_identity_op_pass"],
passes=["identity_op_clean_pass"],
)
......
......@@ -18,20 +18,10 @@ import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestIdentityScaleCleanPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=8,
workspace_size=0,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
config = self.create_inference_config(use_gpu=True)
yield config, ['relu'], (1e-5, 1e-5)
def sample_program_config(self, draw):
......@@ -61,9 +51,7 @@ class TestIdentityScaleCleanPass(PassAutoScanTest):
return program_config
def test(self):
self.run_and_statis(
max_examples=25, passes=["identity_scale_op_clean_pass"]
)
self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
if __name__ == "__main__":
......
......@@ -156,7 +156,7 @@ class TrtConvertRangeStaticTest(TrtLayerAutoScanTest):
def generate_input2():
return np.array([1]).astype(np.int32)
for in_dtype in [2, 5]:
for in_dtype in [2]:
self.in_dtype = in_dtype
dics = [{}]
ops_config = [
......
......@@ -102,7 +102,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
class TestDeleteRepeatedSlicePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['slice', 'cast', 'cast', 'cast'], (1e-5, 1e-5)
yield config, ['slice'], (1e-5, 1e-5)
def sample_program_config(self, draw):
slice_x = draw(
......@@ -122,15 +122,6 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0],
outputs={"Out": ["slice0_out"]},
)
cast_op0 = OpConfig(
"cast",
inputs={
"X": ["slice0_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast0_out"]},
)
slice_op1 = OpConfig(
"slice",
inputs={
......@@ -142,15 +133,6 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0],
outputs={"Out": ["slice1_out"]},
)
cast_op1 = OpConfig(
"cast",
inputs={
"X": ["slice1_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast1_out"]},
)
slice_op2 = OpConfig(
"slice",
inputs={
......@@ -162,16 +144,7 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0],
outputs={"Out": ["slice2_out"]},
)
cast_op2 = OpConfig(
"cast",
inputs={
"X": ["slice2_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast2_out"]},
)
ops = [slice_op0, cast_op0, slice_op1, cast_op1, slice_op2, cast_op2]
ops = [slice_op0, slice_op1, slice_op2]
program_config = ProgramConfig(
ops=ops,
......@@ -179,7 +152,7 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
inputs={
"slice_x": TensorConfig(shape=slice_x),
},
outputs=["cast0_out", "cast1_out", "cast2_out"],
outputs=["slice0_out", "slice1_out", "slice2_out"],
)
return program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册