未验证 提交 aec4e38f 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] Del 2 useless pass (#53414)

* delete delete_fill_constant_op_pass and unsqueeze2_eltwise_fuse_pass
上级 af2ad8d8
......@@ -104,7 +104,6 @@ pass_library(delete_dropout_op_pass inference)
pass_library(delete_concat_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
......@@ -118,7 +117,6 @@ pass_library(fused_multi_transformer_encoder_pass inference)
pass_library(fused_multi_transformer_decoder_pass inference)
pass_library(fuse_multi_transformer_layer_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
......@@ -391,10 +389,6 @@ cc_test(
test_adaptive_pool2d_convert_global_pass
SRCS adaptive_pool2d_convert_global_pass_tester.cc
DEPS adaptive_pool2d_convert_global_pass)
cc_test(
test_unsqueeze2_eltwise_fuse_pass_cc
SRCS unsqueeze2_eltwise_fuse_pass_tester.cc
DEPS unsqueeze2_eltwise_fuse_pass)
cc_test(
test_generate_pass_cc
SRCS generate_pass_tester.cc
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_fill_constant_op_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
// Writes `value` into every element of `out_t`.
//
// The destination buffer is requested on the CPU through
// `mutable_data<T>(platform::CPUPlace())`. The caller in this file resizes the
// tensor before calling, so `numel()` reflects the final element count.
//
// The element count is read once and iterated with a 64-bit index: the
// previous `int` counter could overflow for tensors with more than INT_MAX
// elements (assumes `numel()` returns a 64-bit count -- TODO confirm).
template <typename T>
void FillConstData(phi::DenseTensor* out_t, T value) {
  auto* output_data = out_t->mutable_data<T>(platform::CPUPlace());
  const int64_t element_count = out_t->numel();
  for (int64_t i = 0; i < element_count; ++i) {
    output_data[i] = value;
  }
}
// Replaces matched `fill_constant` ops by a pre-filled persistable tensor.
//
// For every match the handler reads the op's `value` and `shape` attributes,
// fills a phi::DenseTensor stored in the parameter scope under the output
// variable's name, marks that variable persistable and removes the op node
// from the graph.
//
// Pattern restrictions (all visible below):
//   * the op node is a `fill_constant` that passes assert_is_not_op_input for
//     ValueTensor, str_value, ShapeTensor and ShapeTensorList;
//   * its `shape` attribute has exactly one entry;
//   * the output variable node has exactly one outgoing edge, and the node it
//     leads to is neither `conditional_block` nor `while`.
//
// The pass does nothing when the `with_dynamic_shape` attribute is true.
// Only BOOL, INT32, INT64 and FLOAT32 outputs are folded; other dtypes are
// logged and left completely untouched.
void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
  // Dynamic shape is not supported by this pass.
  if (Get<bool>("with_dynamic_shape")) {
    return;
  }
  FusePassBase::Init("delete_fill_constant_op_pass", graph);

  GraphPatternDetector detector;
  auto fill_constant_op =
      detector.mutable_pattern()
          ->NewNode("fill_constant")
          ->assert_is_op("fill_constant")
          ->assert_is_not_op_input("ValueTensor")
          ->assert_is_not_op_input("str_value")
          ->assert_is_not_op_input("ShapeTensor")
          ->assert_is_not_op_input("ShapeTensorList")
          ->assert_more([](Node* node) {
            return node->Op()
                       ->GetAttrIfExists<std::vector<int64_t>>("shape")
                       .size() == 1;
          });
  auto fill_constant_out =
      detector.mutable_pattern()
          ->NewNode("fill_constant_out")
          ->assert_is_op_output("fill_constant")
          ->assert_more([](Node* x) { return x->outputs.size() == 1UL; });
  auto next_op = detector.mutable_pattern()
                     ->NewNode("next_op")
                     ->assert_is_not_op_type("conditional_block")
                     ->assert_is_not_op_type("while");

  // Create the topological connections for the above pattern nodes.
  fill_constant_op->LinksTo({fill_constant_out});
  next_op->LinksFrom({fill_constant_out});

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    Node* fill_constant_op_node = subgraph.at(fill_constant_op);
    Node* fill_constant_out_node = subgraph.at(fill_constant_out);

    // Read fill_constant's attributes.
    auto* fill_constant = fill_constant_op_node->Op();
    auto value = PADDLE_GET_CONST(float, fill_constant->GetAttr("value"));
    auto shape =
        PADDLE_GET_CONST(std::vector<int64_t>, fill_constant->GetAttr("shape"));

    // Validate the dtype BEFORE mutating anything. Previously the variable was
    // already marked persistable and the scope tensor resized when an
    // unsupported dtype was detected, leaving a half-converted variable behind
    // while the op itself stayed in the graph.
    auto* fill_constant_out_desc = fill_constant_out_node->Var();
    const auto dtype =
        framework::TransToPhiDataType(fill_constant_out_desc->GetDataType());
    const bool dtype_supported =
        dtype == phi::DataType::BOOL || dtype == phi::DataType::INT32 ||
        dtype == phi::DataType::INT64 || dtype == phi::DataType::FLOAT32;
    if (!dtype_supported) {
      LOG(WARNING) << "Unsupported dtype for fill_constant op: " << dtype;
      return;
    }

    // Turn the output variable into a persistable constant held by the scope.
    fill_constant_out_desc->SetShape(shape);
    fill_constant_out_desc->SetPersistable(true);
    auto* scope = param_scope();
    auto* fill_constant_out_tensor = scope->Var(fill_constant_out_desc->Name())
                                         ->GetMutable<phi::DenseTensor>();
    fill_constant_out_tensor->Resize(phi::make_ddim(shape));

    // The attribute is stored as float; it is converted to the output dtype.
    switch (dtype) {
      case phi::DataType::BOOL:
        FillConstData<bool>(fill_constant_out_tensor, static_cast<bool>(value));
        break;
      case phi::DataType::INT32:
        FillConstData<int32_t>(fill_constant_out_tensor,
                               static_cast<int32_t>(value));
        break;
      case phi::DataType::INT64:
        FillConstData<int64_t>(fill_constant_out_tensor,
                               static_cast<int64_t>(value));
        break;
      case phi::DataType::FLOAT32:
        FillConstData<float>(fill_constant_out_tensor, value);
        break;
      default:
        // Unreachable: unsupported dtypes were rejected above.
        break;
    }

    // Remove the op node (and its links) from the graph.
    GraphSafeRemoveNodes(graph, {fill_constant_op_node});
  };

  detector(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers DeleteFillConstantOpPass under the pass name
// "delete_fill_constant_op_pass" (the same name ApplyImpl hands to
// FusePassBase::Init).
REGISTER_PASS(delete_fill_constant_op_pass,
paddle::framework::ir::DeleteFillConstantOpPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
// Graph pass that operates on `fill_constant` ops; the behaviour lives in
// ApplyImpl, which is defined in the accompanying .cc file.
class DeleteFillConstantOpPass : public FusePassBase {
 private:
  // The destructor is intentionally left private, as in the original
  // declaration.
  virtual ~DeleteFillConstantOpPass() = default;

 protected:
  // Entry point invoked for the graph this pass is applied to.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern description for an `unsqueeze2` op whose "Out" feeds input "Y" of an
// `elementwise_mul` op. The nodes are created on the PDPattern handed to the
// constructor, under the name scope "unsqueeze2_eltwise_fuse_pass".
struct UnsqueezeEltwise : public PatternBase {
UnsqueezeEltwise(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "unsqueeze2_eltwise_fuse_pass") {}
// Builds the pattern on top of the caller-supplied input nodes:
// `x` is constrained to be input "X" of elementwise_mul, `y` to be input "X"
// of unsqueeze2. Returns the elementwise_mul output node.
PDNode *operator()(PDNode *x, PDNode *y);
// Names of the operator nodes.
PATTERN_DECL_NODE(unsqz);
PATTERN_DECL_NODE(elementwise);
// Names of the variable nodes.
// NOTE(review): eltwise_in_x and unsqz_in are declared here but are not
// referenced by operator() or by the pass handler visible in this file.
PATTERN_DECL_NODE(eltwise_in_x);
PATTERN_DECL_NODE(unsqz_in);
PATTERN_DECL_NODE(unsqz_out);
PATTERN_DECL_NODE(eltwise_out);
};
// Wires up the unsqueeze2 -> elementwise_mul pattern.
//
// `x` must be input "X" of an elementwise_mul and `y` input "X" of an
// unsqueeze2. The unsqueeze2 output has to be input "Y" of the same
// elementwise_mul. The returned node is the elementwise_mul output, which is
// marked as the pattern's output.
PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) {
  // Constraints on the externally created input nodes.
  x->assert_is_op_input("elementwise_mul", "X");
  y->assert_is_op_input("unsqueeze2", "X");

  // unsqueeze2 op and its output variable.
  auto *unsqueeze_op = pattern->NewNode(unsqz_repr());
  unsqueeze_op->assert_is_op("unsqueeze2");
  auto *unsqueeze_result = pattern->NewNode(unsqz_out_repr());
  unsqueeze_result->assert_is_op_output("unsqueeze2", "Out");
  unsqueeze_result->assert_is_op_input("elementwise_mul", "Y");
  unsqueeze_op->LinksFrom({y});
  unsqueeze_op->LinksTo({unsqueeze_result});

  // elementwise_mul op and its output variable.
  auto *mul_op = pattern->NewNode(elementwise_repr());
  mul_op->assert_is_op("elementwise_mul");
  auto *mul_result = pattern->NewNode(eltwise_out_repr());
  mul_result->AsOutput();
  mul_result->assert_is_op_output("elementwise_mul");
  mul_op->LinksFrom({x, unsqueeze_result});
  mul_op->LinksTo({mul_result});

  return mul_result;
}
} // namespace patterns
// Declares the op signatures this pass accepts. The handler in ApplyImpl
// checks matched subgraphs against these declarations through IsCompat.
UnsqueezeEltwiseFusePass::UnsqueezeEltwiseFusePass() {
  // unsqueeze2: X -> Out, with optional AxesTensor / AxesTensorList inputs and
  // an optional XShape output; `axes` must be a vector of ints.
  auto &&unsqueeze2_compat = AddOpCompat(OpCompat("unsqueeze2"));
  unsqueeze2_compat.AddInput("X").IsTensor().End();
  unsqueeze2_compat.AddInput("AxesTensor").IsOptional().IsTensor().End();
  unsqueeze2_compat.AddInput("AxesTensorList").IsOptional().IsTensor().End();
  unsqueeze2_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  unsqueeze2_compat.AddOutput("Out").IsTensor().End();
  unsqueeze2_compat.AddAttr("axes").IsType<std::vector<int>>().End();

  // elementwise_mul: (X, Y) -> Out.
  auto &&elementwise_mul_compat = AddOpCompat(OpCompat("elementwise_mul"));
  elementwise_mul_compat.AddInput("X").IsTensor().End();
  elementwise_mul_compat.AddInput("Y").IsTensor().End();
  elementwise_mul_compat.AddOutput("Out").IsTensor().End();
  // The attribute value is -1 before fusion and 0 after fusion, so both are
  // accepted.
  elementwise_mul_compat.AddAttr("axis").IsIntIn({-1, 0}).End();
}
// Fuses `unsqueeze2(axes=[2,3])` into the following `elementwise_mul`.
//
// A match is rewritten only when, according to the var descs, x has rank 4,
// the unsqueeze2 input y has rank 2, the unsqueeze2 axes are exactly {2, 3}
// and elementwise_mul's axis is -1. In that case elementwise_mul reads y
// directly as its "Y" input with axis=0, and the unsqueeze2 op together with
// its output variable is removed from the graph.
//
// Both pattern inputs must be non-persistable variables. The number of
// rewritten subgraphs is reported through AddStatis.
void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("unsqueeze2_eltwise_fuse_pass", graph);
  int found_subgraph_count = 0;

  GraphPatternDetector gpd;
  auto *x = gpd.mutable_pattern()
                ->NewNode("unsqueeze2_eltwise_fuse_pass/x")
                ->AsInput()
                ->assert_is_op_input("elementwise_mul", "X")
                ->assert_var_not_persistable();
  auto *y = gpd.mutable_pattern()
                ->NewNode("unsqueeze2_eltwise_fuse_pass/y")
                ->AsInput()
                ->assert_is_op_input("unsqueeze2", "X")
                ->assert_var_not_persistable();
  patterns::UnsqueezeEltwise fused_pattern(gpd.mutable_pattern(),
                                           "unsqueeze2_eltwise_fuse_pass");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }
    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    VLOG(4) << "handle UnsqueezeEltwise fuse";
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(unsqz_op, unsqz, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(unsqz_out, unsqz_out, fused_pattern);

    auto *x_node = subgraph.at(x);
    auto *y_node = subgraph.at(y);

    // Ranks come from the var descs and may differ from the runtime ranks.
    const size_t eltwise_in_x_rank = x_node->Var()->GetShape().size();
    const size_t unsqz_in_rank = y_node->Var()->GetShape().size();
    const std::vector<int> unsqz_op_axes =
        PADDLE_GET_CONST(std::vector<int>, unsqz_op->Op()->GetAttr("axes"));
    const int eltwise_op_axis =
        PADDLE_GET_CONST(int, eltwise_op->Op()->GetAttr("axis"));

    const bool can_fuse = eltwise_in_x_rank == 4 && unsqz_in_rank == 2 &&
                          unsqz_op_axes == std::vector<int>{2, 3} &&
                          eltwise_op_axis == -1;
    if (!can_fuse) {
      return;
    }

    eltwise_op->Op()->SetAttr("axis", 0);
    eltwise_op->Op()->SetInput("Y", {y_node->Name()});

    // Only the y -> elementwise_mul edge is new. The matched pattern already
    // guarantees x -> elementwise_mul and elementwise_mul -> eltwise_out, so
    // linking them again (as was done before) only duplicated those edges.
    IR_NODE_LINK_TO(y_node, eltwise_op);
    GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out});
    found_subgraph_count++;

    // The rewritten op is validated after the fact; a failure is only logged.
    if (!IsCompat(*eltwise_op->Op())) {
      LOG(WARNING) << "unsqueeze2_eltwise_fuse_pass op compat failed.";
    }
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers UnsqueezeEltwiseFusePass under the pass name
// "unsqueeze2_eltwise_fuse_pass"; the unit test looks the pass up by this name
// through PassRegistry::Instance().Get(...).
REGISTER_PASS(unsqueeze2_eltwise_fuse_pass,
paddle::framework::ir::UnsqueezeEltwiseFusePass);
// Declares the op-version combination attached to this pass: "unsqueeze2"
// with EQ 0 and "elementwise_mul" with LE 1.
// NOTE(review): presumably this is what
// PassVersionCheckerRegistrar::IsPassCompatible evaluates -- confirm.
REGISTER_PASS_CAPABILITY(unsqueeze2_eltwise_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("unsqueeze2", 0)
.LE("elementwise_mul", 1));
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
// Fuses an unsqueeze2 into the elementwise_mul that consumes it.
//
//   Before:                                  After:
//
//   x (rank 4)      y (rank 2)               x (rank 4)    y (rank 2)
//       |               |                         \           /
//       |     unsqueeze2(axes=[2,3])          elementwise_mul(axis=0)
//       |               |                              |
//      elementwise_mul(axis=-1)
//                |
//
// Notice:
//   The rank of each input is obtained from its var_desc;
//   it may change at runtime.
class UnsqueezeEltwiseFusePass : public FusePassBase {
 public:
  // Declares the op-compat constraints for unsqueeze2 and elementwise_mul.
  UnsqueezeEltwiseFusePass();
  virtual ~UnsqueezeEltwiseFusePass() = default;

 protected:
  // Detects the pattern above in `graph` and rewrites every eligible match.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
// Builds x(1x92x28x28) * unsqueeze2(y(1x92), axes=[2,3]) and checks that the
// pass removes exactly two nodes while a single elementwise_mul remains.
TEST(UnsqueezeEltwiseFusePass, basic) {
  Layers layers;
  auto* x = layers.data("x", {1, 92, 28, 28});
  auto* y = layers.data("y", {1, 92});

  std::vector<int> unsqueeze_axes{2, 3};
  auto* unsqz_out = layers.unsqueeze2(y, unsqueeze_axes);

  AttributeMap mul_attrs;
  mul_attrs["axis"] = -1;
  layers.elementwise_mul(x, unsqz_out, nullptr, &mul_attrs);

  auto graph = std::make_unique<ir::Graph>(layers.main_program());
  auto pass = PassRegistry::Instance().Get("unsqueeze2_eltwise_fuse_pass");

  // Snapshot the graph size before running the pass.
  int num_nodes_before = graph->Nodes().size();
  VLOG(3) << DebugString(graph);

  graph.reset(pass->Apply(graph.release()));

  // Snapshot again after the pass ran.
  int num_nodes_after = graph->Nodes().size();
  int num_fused_nodes_after = GetNumOpNodes(graph, "elementwise_mul");
  VLOG(3) << DebugString(graph);

  // The unsqueeze2 op and its output variable must be gone.
  PADDLE_ENFORCE_EQ(num_nodes_before,
                    num_nodes_after + 2,
                    platform::errors::PreconditionNotMet(
                        "The number of nodes before and after the fuse does "
                        "not meet expectations"));
  // Exactly one elementwise_mul op must remain.
  PADDLE_ENFORCE_EQ(
      num_fused_nodes_after,
      1,
      platform::errors::PreconditionNotMet(
          "The number of fusion nodes does not meet expectations after fuse"));
}
// The registered op-version capability of the pass must report compatible.
TEST(UnsqueezeEltwiseFusePass, pass_op_version_check) {
  auto&& version_checker =
      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance();
  ASSERT_TRUE(
      version_checker.IsPassCompatible("unsqueeze2_eltwise_fuse_pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
// NOTE(review): presumably forces the pass registration to be linked into
// this test binary so that PassRegistry can find it -- confirm.
USE_PASS(unsqueeze2_eltwise_fuse_pass);
......@@ -90,7 +90,6 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_ops_to_matrix_multiply_pass", //
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_fill_constant_op_pass", //
"delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", //
"trt_delete_weight_dequant_linear_op_pass", //
......@@ -123,7 +122,6 @@ const std::vector<std::string> kTRTSubgraphPasses({
"preln_layernorm_x_fuse_pass", //
"reverse_roll_fuse_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading.
#else
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestUnsqueezeEltwiseFusePass(PassAutoScanTest):
    r"""
    Auto-scan test for ``unsqueeze2_eltwise_fuse_pass``.

    The generated program has the following structure:

            y_var
              |
          unsqueeze2
              \
        unsqueeze2_out_var    x_var
                    \          /
                  elementwise_mul
    """

    def sample_predictor_configs(self, program_config):
        """Yield one TensorRT FP32 predictor config.

        Each yielded item is ``(config, op types expected to remain,
        (atol, rtol))``.
        """
        trt_config = self.create_trt_inference_config()
        trt_config.enable_tensorrt_engine(
            max_batch_size=10,
            workspace_size=102400,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Float32,
            use_static=False,
            use_calib_mode=False,
        )
        remaining_ops = ['elementwise_mul']
        tolerance = (1e-5, 1e-5)
        yield trt_config, remaining_ops, tolerance

    def sample_program_config(self, draw):
        """Draw a random rank-4 X and build the unsqueeze2 + mul program."""
        # X is rank 4 with every dimension in [1, 10].
        dim_strategy = st.integers(min_value=1, max_value=10)
        x_shape = draw(st.lists(dim_strategy, min_size=4, max_size=4))

        # The unsqueeze2 input takes the two leading dims of X, so that
        # unsqueezing axes [2, 3] gives a rank-4 tensor.
        y_shape = x_shape[:2]

        unsqueeze2_op = OpConfig(
            "unsqueeze2",
            inputs={
                "X": ["unsqueeze2_x"],
                "AxesTensor": [],
                "AxesTensorList": [],
            },
            outputs={"Out": ["unsqueeze2_out"], "XShape": ["xshape"]},
            axes=[2, 3],
        )
        mul_op = OpConfig(
            "elementwise_mul",
            inputs={"Y": ["unsqueeze2_out"], "X": ["mul_x"]},
            outputs={"Out": ["mul_out"]},
            axis=-1,
        )

        return ProgramConfig(
            ops=[unsqueeze2_op, mul_op],
            weights={},
            inputs={
                "mul_x": TensorConfig(shape=x_shape),
                "unsqueeze2_x": TensorConfig(shape=y_shape),
            },
            outputs=mul_op.outputs["Out"],
        )

    def test(self):
        """Run the auto-scan with up to 300 examples and no quantization."""
        self.run_and_statis(
            quant=False,
            max_examples=300,
            passes=["unsqueeze2_eltwise_fuse_pass"],
        )
# Run the tests when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
......@@ -131,7 +131,6 @@ HIGH_PARALLEL_JOB_NEW = [
'test_conv_concat_relu_mkldnn_fuse_pass',
'test_bf16_utils',
'test_sum_bf16_mkldnn_op',
'test_unsqueeze2_eltwise_fuse_pass_cc',
'dense_table_test',
'test_collective_optimizer',
'test_origin_info',
......@@ -2145,7 +2144,6 @@ CPU_PARALLEL_JOB = [
'test_recv_save_op',
'heter_listen_and_server_test',
'test_analyzer_capi_ner',
'test_unsqueeze2_eltwise_fuse_pass_cc',
'test_dgc_optimizer',
'heter_server_test',
'test_custom_conj',
......
......@@ -178,7 +178,6 @@ disable_win_inference_test="^trt_quant_int8_yolov3_r50_test$|\
^test_trt_convert_multihead_matmul$|\
^test_trt_convert_prelu$|\
^test_trt_fc_fuse_quant_dequant_pass$|\
^test_unsqueeze2_eltwise_fuse_pass$|\
^test_parallel_executor_seresnext_with_fuse_all_reduce_gpu$|\
^test_parallel_executor_seresnext_with_reduce_gpu$|\
^test_api_impl$|\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册