diff --git a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc
index 92a2634938bd3697712c1733a5449e8c24635238..414d166e4e3d6ddce66802ba9ea5f46a56f7cf88 100644
--- a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc
@@ -81,6 +81,49 @@ Reshape2MatmulV2Pattern::Reshape2MatmulV2Pattern(PDPattern* pattern,
   reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x});
   matmul_v2->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
 }
+
+// Matches: transpose2(axis=[0, 2, 1]) -> matmul_v2(Y=..., trans_y=false),
+// so the transpose can be folded into matmul_v2's trans_y attribute.
+struct Transpose2MatmulV2Pattern : public PatternBase {
+  Transpose2MatmulV2Pattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(transpose2);
+  PATTERN_DECL_NODE(matmul);
+  // declare variable node's name
+  PATTERN_DECL_NODE(transpose2_in);
+  PATTERN_DECL_NODE(matmul_x);
+  PATTERN_DECL_NODE(matmul_y);
+  PATTERN_DECL_NODE(matmul_out);
+};
+
+Transpose2MatmulV2Pattern::Transpose2MatmulV2Pattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* transpose2_in = pattern->NewNode(transpose2_in_repr())
+                            ->assert_is_op_input("transpose2", "X")
+                            ->AsInput();
+  auto* transpose2 =
+      pattern->NewNode(transpose2_repr())
+          ->assert_is_op("transpose2")
+          ->assert_more([](Node* node) {
+            // Only a swap of the last two dims can be expressed via trans_y.
+            auto axis = node->Op()->GetAttrIfExists<std::vector<int>>("axis");
+            return axis.size() == 3 && axis[0] == 0 && axis[1] == 2 &&
+                   axis[2] == 1;  // axis == [0, 2, 1]
+          });
+  auto* matmul_x =
+      pattern->NewNode(matmul_x_repr())->assert_is_op_input("matmul_v2", "X");
+  auto* matmul_y = pattern->NewNode(matmul_y_repr())
+                       ->assert_is_op_input("matmul_v2", "Y")
+                       ->assert_is_op_output("transpose2", "Out");
+  auto* matmul = pattern->NewNode(matmul_repr())
+                     ->assert_is_op("matmul_v2")
+                     ->assert_op_attr("trans_y", false);
+  auto* matmul_out = pattern->NewNode(matmul_out_repr())
+                         ->assert_is_op_output("matmul_v2", "Out")
+                         ->AsOutput();
+  transpose2->LinksFrom({transpose2_in}).LinksTo({matmul_y});
+  matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
+}
+
 }  // namespace patterns
 
 void MatmulWeightTransPass::TransMatmulV2Weight(ir::Graph* graph) const {
@@ -119,12 +162,51 @@ void MatmulWeightTransPass::TransMatmulV2Weight(ir::Graph* graph) const {
   AddStatis(found_subgraph_count);
 }
 
+// Folds transpose2(axis=[0, 2, 1]) feeding matmul_v2's Y input into the
+// matmul itself by flipping trans_y, then removes the transpose2 op and its
+// intermediate output from the graph.
+void MatmulWeightTransPass::FuseTranspose2MatmulV2(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::Transpose2MatmulV2Pattern pattern(gpd.mutable_pattern(),
+                                              name_scope_);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle FuseTranspose2Matmul";
+    /* declare operator node's name */
+    GET_IR_NODE(transpose2);
+    GET_IR_NODE(matmul);
+    /* declare variable node's name*/
+    GET_IR_NODE(transpose2_in);
+    GET_IR_NODE(matmul_x);
+    GET_IR_NODE(matmul_y);
+    GET_IR_NODE(matmul_out);
+
+    auto* scope = param_scope();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
+
+    // Re-wire matmul_v2's Y to the transpose input and flip trans_y.
+    matmul->Op()->RenameInput(matmul_y->Name(), transpose2_in->Name());
+    matmul->Op()->SetAttr("trans_y", true);
+    matmul->Op()->Flush();
+
+    IR_NODE_LINK_TO(transpose2_in, matmul);
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes = {transpose2, matmul_y};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
 void MatmulWeightTransPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
   TransMatmulV2Weight(graph);
+  FuseTranspose2MatmulV2(graph);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h
index 2eeaeaaca3c89b5be156f6293d12bab11488db27..e54639140045db2077f446517d76c6095481f34c 100644
--- a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h
+++ b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h
@@ -51,6 +51,7 @@ class MatmulWeightTransPass : public FusePassBase {
 
  private:
   void TransMatmulV2Weight(ir::Graph* graph) const;
+  void FuseTranspose2MatmulV2(ir::Graph* graph) const;
 
   const std::string name_scope_{"matmul_weight_trans_pass"};
 };
diff --git a/test/ir/inference/test_xpu_matmul_weight_trans_pass.py b/test/ir/inference/test_xpu_matmul_weight_trans_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d3f89a8792e19df7b56a347dd55f04c5e2f8a79
--- /dev/null
+++ b/test/ir/inference/test_xpu_matmul_weight_trans_pass.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import hypothesis.strategies as st
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestXpuMatmulV2WeightTransPass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        # xpu
+        config = self.create_inference_config(use_xpu=True)
+        yield config, [
+            "matmul_v2",
+        ], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        # 1. Generate shape and attr of matmul
+        x_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=8), min_size=3, max_size=3
+            )
+        )
+        transpose_shape = x_shape
+        transpose_op = OpConfig(
+            "transpose2",
+            inputs={"X": ["transpose_input"]},
+            outputs={"Out": ["transpose_out"]},
+            axis=[0, 2, 1],
+        )
+        matmul_op = OpConfig(
+            "matmul_v2",
+            inputs={"X": ["matmul_x"], "Y": ["transpose_out"]},
+            outputs={"Out": ["matmul_out"]},
+            trans_x=False,
+            trans_y=False,
+        )
+        ops = [transpose_op, matmul_op]
+        weights = {}
+        inputs = {
+            "matmul_x": TensorConfig(shape=x_shape),
+            "transpose_input": TensorConfig(shape=transpose_shape),
+        }
+        program_config = ProgramConfig(
+            ops=ops,
+            weights=weights,
+            inputs=inputs,
+            outputs=ops[-1].outputs["Out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            min_success_num=5,
+            passes=["matmul_weight_trans_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()