Unverified commit 54a101d5, authored by wz1qqx, committed by GitHub

[XPU] add reduce_max_fuse_pass (#54981)

Parent: 97e87d2d
......@@ -271,6 +271,7 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(fold_two_squeeze2_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(reduce_max_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/reduce_max_fuse_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
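// Match the following subgraph over a rank-3 input x of shape [B, H, W]:
//   transpose2(axis=[0,2,1]) : [B, H, W]    -> [B, W, H]
//   unsqueeze2(axes=[2])     : [B, W, H]    -> [B, W, 1, H]
//   pool2d(max, ksize=[1,H]) : [B, W, 1, H] -> [B, W, 1, 1]
//   squeeze2(axes=[2])       : [B, W, 1, 1] -> [B, W, 1]
//   transpose2(axis=[0,2,1]) : [B, W, 1]    -> [B, 1, W]
// The chain is equivalent to reduce_max(dim={-2}, keep_dim=true) on x.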
struct ReduceMaxFusePattern : public PatternBase {
ReduceMaxFusePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(unsqueeze2);
PATTERN_DECL_NODE(pool2d);
PATTERN_DECL_NODE(squeeze2);
PATTERN_DECL_NODE(transpose2_2);
// declare variable node's name
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(unsqueeze2_out);
PATTERN_DECL_NODE(pool2d_out);
PATTERN_DECL_NODE(squeeze2_out);
PATTERN_DECL_NODE(transpose2_2_out);
};
ReduceMaxFusePattern::ReduceMaxFusePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input("transpose2", "X")
->assert_more([](Node* node) {
auto x_shape = node->Var()->GetShape();
size_t x_rank = x_shape.size();
return x_rank == 3;
});
auto* transpose2_1 =
pattern->NewNode(transpose2_1_repr())
->assert_is_op("transpose2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axis_array =
op_desc->GetAttrIfExists<std::vector<int>>("axis");
return axis_array == std::vector<int>{0, 2, 1};
});
auto* transpose2_1_out = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("unsqueeze2", "X");
auto* unsqueeze2 =
pattern->NewNode(unsqueeze2_repr())
->assert_is_op("unsqueeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array == std::vector<int>{2};
});
auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->assert_is_op_output("unsqueeze2", "Out")
->assert_is_op_input("pool2d", "X");
auto* pool2d =
pattern->NewNode(pool2d_repr())
->assert_is_op("pool2d")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto input_var = node->inputs[0]->Var();
auto pool2d_x_shape = input_var->GetShape();
std::vector<int> HW = {static_cast<int>(pool2d_x_shape[2]),
static_cast<int>(pool2d_x_shape[3])};
auto pool_type =
op_desc->GetAttrIfExists<std::string>("pooling_type");
auto ksize_array =
op_desc->GetAttrIfExists<std::vector<int>>("ksize");
auto strides_array =
op_desc->GetAttrIfExists<std::vector<int>>("strides");
auto paddings_array =
op_desc->GetAttrIfExists<std::vector<int>>("paddings");
return pool_type == "max" && ksize_array == HW &&
strides_array == HW &&
paddings_array == std::vector<int>{0, 0};
});
auto* pool2d_out = pattern->NewNode(pool2d_out_repr())
->assert_is_op_output("pool2d", "Out")
->assert_is_op_input("squeeze2", "X");
auto* squeeze2 = pattern->NewNode(squeeze2_repr())
->assert_is_op("squeeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axes_array =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return axes_array == std::vector<int>{2};
});
auto* squeeze2_out = pattern->NewNode(squeeze2_out_repr())
->assert_is_op_output("squeeze2", "Out")
->assert_is_op_input("transpose2", "X");
auto* transpose2_2 =
pattern->NewNode(transpose2_2_repr())
->assert_is_op("transpose2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto axis_array =
op_desc->GetAttrIfExists<std::vector<int>>("axis");
return axis_array == std::vector<int>{0, 2, 1};
});
auto* transpose2_2_out = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2", "Out");
transpose2_1->LinksFrom({x}).LinksTo({transpose2_1_out});
unsqueeze2->LinksFrom({transpose2_1_out}).LinksTo({unsqueeze2_out});
pool2d->LinksFrom({unsqueeze2_out}).LinksTo({pool2d_out});
squeeze2->LinksFrom({pool2d_out}).LinksTo({squeeze2_out});
transpose2_2->LinksFrom({squeeze2_out}).LinksTo({transpose2_2_out});
}
} // namespace patterns
void ReduceMaxFusePass::FuseReduceMax(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ReduceMaxFusePattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ReduceMaxFusePass";
// declare operator node's name
GET_IR_NODE(x);
GET_IR_NODE(transpose2_1);
GET_IR_NODE(unsqueeze2);
GET_IR_NODE(pool2d);
GET_IR_NODE(squeeze2);
GET_IR_NODE(transpose2_2);
// declare variable node's name
GET_IR_NODE(transpose2_1_out);
GET_IR_NODE(unsqueeze2_out);
GET_IR_NODE(pool2d_out);
GET_IR_NODE(squeeze2_out);
GET_IR_NODE(transpose2_2_out);
auto* block = transpose2_1->Op()->Block();
// Generate an equivalent reduce_max op
framework::OpDesc reduce_op_desc(block);
reduce_op_desc.SetType("reduce_max");
reduce_op_desc.SetInput("X", {x->Name()});
reduce_op_desc.SetAttr("dim", std::vector<int>{-2});
reduce_op_desc.SetAttr("reduce_all", false);
reduce_op_desc.SetAttr("keep_dim", true);
reduce_op_desc.SetOutput("Out", {transpose2_2_out->Name()});
auto* reduce = graph->CreateOpNode(&reduce_op_desc);
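// Rewire the graph: x feeds the new reduce_max node, which writes to the
// original final output var, so downstream consumers are untouched.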
IR_NODE_LINK_TO(x, reduce);
IR_NODE_LINK_TO(reduce, transpose2_2_out);
// delete the now-dead intermediate nodes
std::unordered_set<const Node*> delete_nodes = {transpose2_1,
transpose2_1_out,
unsqueeze2,
unsqueeze2_out,
pool2d,
pool2d_out,
squeeze2,
squeeze2_out,
transpose2_2};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void ReduceMaxFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FuseReduceMax(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reduce_max_fuse_pass, paddle::framework::ir::ReduceMaxFusePass);
REGISTER_PASS_CAPABILITY(reduce_max_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"reduce_max", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
/*
Fuse a series of small ops into a single reduce_max op.
For example:
graph:
      x [B, H, W]
      |
  transpose2 (axis=[0,2,1])
      |
  unsqueeze2 (axes=[2])
      |
  pool2d (pooling_type=max, ksize=strides=input HW, paddings=[0,0])
      |
  squeeze2 (axes=[2])
      |
  transpose2 (axis=[0,2,1])
      |
------------------------------------------------------
After the pass is applied:
      x [B, H, W]
      |
  reduce_max (dim=[-2], keep_dim=true)
      |
*/
class ReduceMaxFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FuseReduceMax(ir::Graph* graph) const;
const std::string name_scope_{"reduce_max_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -526,6 +526,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"one_beam_size_fuse_pass",
"fold_interp_outsize_fuse_pass",
"fold_two_squeeze2_fuse_pass",
"reduce_max_fuse_pass",
"delete_cast_op_pass",
"xpu_delete_cast_op_pass",
"stack_fuse_pass",
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestReduceMaxFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
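        # After the pass, the program should fuse into a single reduce_max
        # op; outputs must match within (rtol, atol) = (1e-3, 1e-3).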
yield config, ["reduce_max"], (1e-3, 1e-3)
def sample_program_config(self, draw):
s_axes = [2]
batch_size = draw(st.integers(min_value=1, max_value=4))
H = draw(st.integers(min_value=1, max_value=64))
W = draw(st.integers(min_value=1, max_value=64))
in_shape = [batch_size, H, W]
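        # Build the pattern the pass matches: transpose to [B, W, H],
        # unsqueeze to [B, W, 1, H], global max-pool over the last two dims,
        # then squeeze and transpose back, yielding [B, 1, W].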
transpose_op1 = OpConfig(
type='transpose2',
inputs={
"X": ["transpose_in"],
},
outputs={"Out": ["transpose_out1"]},
attrs={"axis": [0, 2, 1]},
)
unsqueeze2_op = OpConfig(
type="unsqueeze2",
inputs={"X": ["transpose_out1"]},
outputs={"Out": ["unsqueeze_out"]},
attrs={
"axes": s_axes,
},
)
pool_op = OpConfig(
"pool2d",
inputs={"X": ["unsqueeze_out"]},
outputs={"Out": ["pool_out"]},
ksize=[1, H],
adaptive=False,
pooling_type="max",
data_format="NCHW",
strides=[1, H],
paddings=[0, 0],
ceil_mode=False,
global_pooling=False,
padding_algorithm="EXPLICIT",
exclusive=True,
)
squeeze2_op = OpConfig(
"squeeze2",
inputs={
"X": ["pool_out"],
},
axes=s_axes,
outputs={"Out": ["squeeze2_out"], "XShape": ["xshape"]},
)
transpose_op2 = OpConfig(
type='transpose2',
inputs={
"X": ["squeeze2_out"],
},
outputs={"Out": ["transpose_out2"]},
attrs={"axis": [0, 2, 1]},
)
ops = [
transpose_op1,
unsqueeze2_op,
pool_op,
squeeze2_op,
transpose_op2,
]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"transpose_in": TensorConfig(shape=in_shape),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["reduce_max_fuse_pass"],
)
if __name__ == "__main__":
np.random.seed(200)
unittest.main()
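For intuition, a minimal standalone numpy sketch (not part of the commit) of the equivalence the pass exploits; shapes are illustrative only:

import numpy as np

x = np.random.rand(2, 5, 7).astype(np.float32)  # [B, H, W]
t = x.transpose(0, 2, 1)                        # transpose2(axis=[0,2,1]) -> [B, W, H]
u = t[:, :, None, :]                            # unsqueeze2(axes=[2])     -> [B, W, 1, H]
p = u.max(axis=(2, 3), keepdims=True)           # pool2d(max, global)      -> [B, W, 1, 1]
s = np.squeeze(p, axis=2)                       # squeeze2(axes=[2])       -> [B, W, 1]
subgraph_out = s.transpose(0, 2, 1)             # transpose2(axis=[0,2,1]) -> [B, 1, W]
fused_out = x.max(axis=-2, keepdims=True)       # reduce_max(dim=[-2], keep_dim=True)
assert np.array_equal(subgraph_out, fused_out)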