未验证 提交 9dd85b6b 编写于 作者: C csy0225 提交者: GitHub

[Inference][XPU] Add transpose + matmul fuse pass. (#55770)

上级 352172ac
......@@ -81,6 +81,49 @@ Reshape2MatmulV2Pattern::Reshape2MatmulV2Pattern(PDPattern* pattern,
reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x});
matmul_v2->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
// Pattern describing a transpose2 op whose output is consumed as the Y input
// of a matmul_v2 op. The node wiring is set up in the constructor.
struct Transpose2MatmulV2Pattern : public PatternBase {
  Transpose2MatmulV2Pattern(PDPattern* pattern, const std::string& name_scope);

  // Operator nodes.
  PATTERN_DECL_NODE(transpose2);
  PATTERN_DECL_NODE(matmul);

  // Variable nodes.
  PATTERN_DECL_NODE(transpose2_in);
  PATTERN_DECL_NODE(matmul_x);
  PATTERN_DECL_NODE(matmul_y);
  PATTERN_DECL_NODE(matmul_out);
};
// Builds the subgraph pattern
//
//   transpose2_in -> transpose2(axis = [0, 2, 1]) -> matmul_y
//   (matmul_x, matmul_y) -> matmul_v2(trans_y = false) -> matmul_out
//
// pattern:    pattern graph the nodes are registered in.
// name_scope: forwarded to PatternBase as both the name scope and the repr.
Transpose2MatmulV2Pattern::Transpose2MatmulV2Pattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto* transpose2_in = pattern->NewNode(transpose2_in_repr())
                            ->assert_is_op_input("transpose2", "X")
                            ->AsInput();
  auto* transpose2 =
      pattern->NewNode(transpose2_repr())
          ->assert_is_op("transpose2")
          ->assert_more([](Node* node) {
            // Only a swap of the last two dims of a 3-D tensor is accepted,
            // i.e. axis == [0, 2, 1].
            auto axis = node->Op()->GetAttrIfExists<std::vector<int>>("axis");
            return axis.size() == 3 && axis[0] == 0 && axis[1] == 2 &&
                   axis[2] == 1;
          });
  auto* matmul_x =
      pattern->NewNode(matmul_x_repr())->assert_is_op_input("matmul_v2", "X");
  // The transposed tensor is removed by the fusion, so it must have the
  // matmul as its only consumer; otherwise other readers would be left
  // pointing at a deleted variable.
  auto* matmul_y = pattern->NewNode(matmul_y_repr())
                       ->assert_is_op_input("matmul_v2", "Y")
                       ->assert_is_op_output("transpose2", "Out")
                       ->assert_has_n_outputs(1);
  auto* matmul = pattern->NewNode(matmul_repr())
                     ->assert_is_op("matmul_v2")
                     ->assert_op_attr<bool>("trans_y", false);
  auto* matmul_out = pattern->NewNode(matmul_out_repr())
                         ->assert_is_op_output("matmul_v2", "Out")
                         ->AsOutput();
  transpose2->LinksFrom({transpose2_in}).LinksTo({matmul_y});
  matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
} // namespace patterns
void MatmulWeightTransPass::TransMatmulV2Weight(ir::Graph* graph) const {
......@@ -119,12 +162,51 @@ void MatmulWeightTransPass::TransMatmulV2Weight(ir::Graph* graph) const {
AddStatis(found_subgraph_count);
}
// Folds a transpose2 feeding the Y input of a matmul_v2 into the matmul
// itself: the matmul reads the transpose's input directly with trans_y set to
// true, and the transpose2 op plus its output variable are removed.
//
// graph: graph to rewrite in place. The number of fused subgraphs is reported
//        through AddStatis().
void MatmulWeightTransPass::FuseTranspose2MatmulV2(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::Transpose2MatmulV2Pattern pattern(gpd.mutable_pattern(),
                                              name_scope_);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle FuseTranspose2Matmul";
    /* declare operator node's name */
    GET_IR_NODE(transpose2);
    GET_IR_NODE(matmul);
    /* declare variable node's name*/
    GET_IR_NODE(transpose2_in);
    GET_IR_NODE(matmul_x);
    GET_IR_NODE(matmul_y);
    GET_IR_NODE(matmul_out);
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

    // The transposed variable is deleted below. If anything besides this
    // matmul still reads it, fusing would leave that consumer without its
    // input, so leave the subgraph untouched.
    if (matmul_y->outputs.size() != 1) {
      VLOG(4) << "skip FuseTranspose2Matmul: transpose2 output has other "
                 "consumers";
      return;
    }

    // Feed the un-transposed tensor to the matmul and let the matmul perform
    // the transpose through its trans_y attribute.
    matmul->Op()->RenameInput(matmul_y->Name(), transpose2_in->Name());
    matmul->Op()->SetAttr("trans_y", true);
    matmul->Op()->Flush();
    IR_NODE_LINK_TO(transpose2_in, matmul);
    // delete useless node
    std::unordered_set<const Node*> delete_nodes = {transpose2, matmul_y};
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };
  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
// Pass entry point: validates the graph, initialises the pass, then runs the
// weight-transpose rewrite followed by the transpose2 + matmul_v2 fusion.
void MatmulWeightTransPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  // Order matters only in that both rewrites operate on the same graph; they
  // are invoked in this fixed sequence.
  TransMatmulV2Weight(graph);
  FuseTranspose2MatmulV2(graph);
}
} // namespace ir
......
......@@ -51,6 +51,7 @@ class MatmulWeightTransPass : public FusePassBase {
private:
void TransMatmulV2Weight(ir::Graph* graph) const;
void FuseTranspose2MatmulV2(ir::Graph* graph) const;
const std::string name_scope_{"matmul_weight_trans_pass"};
};
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestXpuMatmulV2WeightTransPass(PassAutoScanTest):
    """Auto-scan test for ``matmul_weight_trans_pass`` on a
    transpose2 -> matmul_v2 program.

    The sampled program transposes a 3-D tensor with ``axis=[0, 2, 1]`` and
    feeds the result to the ``Y`` input of a ``matmul_v2`` op.
    """

    def sample_predictor_configs(self, program_config):
        # xpu
        config = self.create_inference_config(use_xpu=True)
        # NOTE(review): presumably the list is the op types expected after the
        # pass and the tuple is the comparison tolerances -- confirm against
        # PassAutoScanTest.
        yield config, [
            "matmul_v2",
        ], (1e-3, 1e-3)

    def sample_program_config(self, draw):
        # 1. Generate shape and attr of matmul
        x_shape = draw(
            st.lists(
                st.integers(min_value=1, max_value=8), min_size=3, max_size=3
            )
        )
        # The transpose input shares the shape of matmul's X, so after the
        # [0, 2, 1] transpose the inner dimensions of X and Y line up.
        transpose_shape = x_shape

        transpose_op = OpConfig(
            "transpose2",
            inputs={"X": ["transpose_input"]},
            outputs={"Out": ["transpose_out"]},
            axis=[0, 2, 1],
        )
        # matmul_v2 names its transpose attributes trans_x / trans_y
        # (transpose_X / transpose_Y belong to the legacy matmul op).
        matmul_op = OpConfig(
            "matmul_v2",
            inputs={"X": ["matmul_x"], "Y": ["transpose_out"]},
            outputs={"Out": ["matmul_out"]},
            trans_x=False,
            trans_y=False,
        )
        ops = [transpose_op, matmul_op]
        weights = {}
        inputs = {
            "matmul_x": TensorConfig(shape=x_shape),
            "transpose_input": TensorConfig(shape=transpose_shape),
        }
        program_config = ProgramConfig(
            ops=ops,
            weights=weights,
            inputs=inputs,
            outputs=ops[-1].outputs["Out"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            min_success_num=5,
            passes=["matmul_weight_trans_pass"],
        )
# Run the test case when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册