diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 873868c4ac8afbe428504733b93f8a6498eec806..88d4cd7a5e4ed6d494f1dd70d25c1c90fa1e822c 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -260,6 +260,11 @@ if(WITH_XPU) ${XPU_PASS_DEPS}) pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(fold_two_squeeze2_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) + pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) endif() cc_library( @@ -547,4 +552,16 @@ if(WITH_XPU) test_fold_interp_outsize_fuse_pass SRCS xpu/fold_interp_outsize_fuse_pass_test.cc DEPS fold_interp_outsize_fuse_pass) + cc_test( + test_fold_two_squeeze2_fuse_pass + SRCS xpu/fold_two_squeeze2_fuse_pass_test.cc + DEPS fold_two_squeeze2_fuse_pass) + cc_test( + test_matmul_weight_trans_pass + SRCS xpu/matmul_weight_trans_pass_test.cc + DEPS matmul_weight_trans_pass) + cc_test( + test_reshape2_matmul_xpu_fuse_pass + SRCS xpu/reshape2_matmul_xpu_fuse_pass_test.cc + DEPS reshape2_matmul_xpu_fuse_pass) endif() diff --git a/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc old mode 100755 new mode 100644 diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index ed52eb3190c50609c38547270213a1139800e22c..1fb8dd1a8db0cc1ab8360c880f5ac36132c7d1df 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -134,6 +134,22 @@ struct Layers { return out; } + VarDesc* squeeze2(VarDesc* x, + const std::vector axes = {-1}, + bool with_xshape = false) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("squeeze2"); + op->SetInput("X", {x->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("axes", axes); + if (with_xshape) { + VarDesc* xshape = lod_tensor(unique_name()); + op->SetOutput("XShape", {xshape->Name()}); + } + return out; + } + VarDesc* unsqueeze2(VarDesc* x, const std::vector axes = {-1}) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); @@ -420,6 +436,17 @@ struct Layers { return out; } + VarDesc* clip(VarDesc* x, VarDesc* min, VarDesc* max) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("clip"); + op->SetInput("X", {x->Name()}); + op->SetInput("Min", {min->Name()}); + op->SetInput("Max", {max->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; + } + VarDesc* matmul_v2(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr, diff --git a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc index 0a3db6c73be1e2e3894e0bec5431fd4ca19155d5..45593dfc63e436df7b0baf9b7c992add22574a89 100644 --- a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.cc @@ -22,23 +22,14 @@ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -} // namespace paddle - namespace paddle { namespace framework { namespace ir { namespace patterns { -struct DetectorFusePattern : public PatternBase { - DetectorFusePattern(PDPattern* pattern, const std::string& name_scope); + +struct InterpOutsizeFusePattern : public PatternBase { + InterpOutsizeFusePattern(PDPattern* pattern, const std::string& name_scope); // declare operator node's name PATTERN_DECL_NODE(shape); @@ -60,8 +51,8 @@ struct DetectorFusePattern : public PatternBase { PATTERN_DECL_NODE(cast2_out); }; -DetectorFusePattern::DetectorFusePattern(PDPattern* pattern, - const std::string& name_scope) +InterpOutsizeFusePattern::InterpOutsizeFusePattern( + PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, name_scope) { auto* x = pattern->NewNode(x_repr()) ->assert_is_op_input("shape", "Input") @@ -144,9 +135,10 @@ DetectorFusePattern::DetectorFusePattern(PDPattern* pattern, } // namespace patterns -void FoldInterpOutsizeFusePass::DetectorFuse(ir::Graph* graph) const { +void FoldInterpOutsizeFusePass::FoldInterpOutsize(ir::Graph* graph) const { GraphPatternDetector gpd; - patterns::DetectorFusePattern pattern(gpd.mutable_pattern(), name_scope_); + patterns::InterpOutsizeFusePattern pattern(gpd.mutable_pattern(), + name_scope_); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -213,7 +205,7 @@ void FoldInterpOutsizeFusePass::ApplyImpl(ir::Graph* graph) const { graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - DetectorFuse(graph); + FoldInterpOutsize(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h index 08dc0fe7b73976b2e209ccf03e71bdd3d76ab8ef..01b94ea90b4ff81efc6e566c4d56efa7eab01e12 100644 --- a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h @@ -64,7 +64,7 @@ class FoldInterpOutsizeFusePass : public FusePassBase { | / bilinear_interp_v2 */ - void DetectorFuse(ir::Graph* graph) const; + void FoldInterpOutsize(ir::Graph* graph) const; const std::string name_scope_{"fold_interp_outsize_fuse_pass"}; }; diff --git a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc index e7836a27b4561c353274f7c34f62be601ed48373..e39a3c45c944bb2e93d3e0e6802f1add7c2efc39 100644 --- a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc +++ b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc @@ -20,7 +20,7 @@ namespace paddle { namespace framework { namespace ir { -TEST(DetectorFuse, basic) { +TEST(FoldInterpOutsizeFusePass, basic) { Layers layers; auto* block = layers.Block(); diff --git a/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e04bd9692545d1a39af566e2c5f4b620b857ef40 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.h" +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace patterns { + +struct TwoSqueeze2FusePattern : public PatternBase { + TwoSqueeze2FusePattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(squeeze2_1); + PATTERN_DECL_NODE(squeeze2_2); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(squeeze2_1_out); + PATTERN_DECL_NODE(squeeze2_2_out); +}; + +TwoSqueeze2FusePattern::TwoSqueeze2FusePattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("squeeze2", "X") + ->AsInput() + ->assert_more([](Node* node) { + auto squeeze2_in_x_shape = node->Var()->GetShape(); + size_t squeeze2_in_rank = squeeze2_in_x_shape.size(); + bool nice_shape = squeeze2_in_x_shape[1] == 1 && + squeeze2_in_x_shape[2] == 74 && + squeeze2_in_x_shape[3] == 1; + return squeeze2_in_rank == 4 && nice_shape; + }); + auto* squeeze2_1 = pattern->NewNode(squeeze2_1_repr()) + ->assert_is_op("squeeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>( + "axes") == std::vector{3}; + }); + auto* squeeze2_1_out = pattern->NewNode(squeeze2_1_out_repr()) + ->assert_is_op_output("squeeze2", "Out") + ->assert_has_n_outputs(1) + ->assert_is_op_input("squeeze2", "X"); + squeeze2_1->LinksFrom({x}).LinksTo({squeeze2_1_out}); + auto* squeeze2_2 = pattern->NewNode(squeeze2_2_repr()) + ->assert_is_op("squeeze2") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>( + "axes") == std::vector{1}; + }); + auto* squeeze2_2_out = pattern->NewNode(squeeze2_2_out_repr()) + ->assert_is_op_output("squeeze2", "Out") + ->assert_has_n_outputs(1); + squeeze2_2->LinksFrom({squeeze2_1_out}).LinksTo({squeeze2_2_out}); +} + +} // namespace patterns + +void FoldTwoSqueeze2FusePass::FoldTwoSqueeze2(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::TwoSqueeze2FusePattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FoldTwoSqueeze2FusePass"; + // declare operator node's name + GET_IR_NODE(squeeze2_1); + GET_IR_NODE(squeeze2_2); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(squeeze2_1_out); + GET_IR_NODE(squeeze2_2_out); + + auto* block = squeeze2_1->Op()->Block(); + // Generate reshape2 op + framework::OpDesc reshape2_op_desc(block); + reshape2_op_desc.SetType("reshape2"); + reshape2_op_desc.SetInput("X", {x->Name()}); + reshape2_op_desc.SetAttr("shape", std::vector{-1, 74}); + reshape2_op_desc.SetOutput("Out", {squeeze2_2_out->Name()}); + + auto* reshape2 = graph->CreateOpNode(&reshape2_op_desc); + + IR_NODE_LINK_TO(x, reshape2); + IR_NODE_LINK_TO(reshape2, squeeze2_2_out); + // delete useless node + std::unordered_set delete_nodes = { + squeeze2_1, squeeze2_2, squeeze2_1_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void FoldTwoSqueeze2FusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + FoldTwoSqueeze2(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fold_two_squeeze2_fuse_pass, + paddle::framework::ir::FoldTwoSqueeze2FusePass); + +REGISTER_PASS_CAPABILITY(fold_two_squeeze2_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "squeeze2", 0)); diff --git a/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.h b/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..af67754d766568c81806812917738b9024844a06 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.h @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +/* +Origin subgraph: + x + | + squeeze2 + | + squeeze2 + | + +Fused subgraph: + x + | + reshape2 + | +*/ +class FoldTwoSqueeze2FusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void FoldTwoSqueeze2(ir::Graph* graph) const; + + const std::string name_scope_{"fold_two_squeeze2_fuse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..64f8a7728089abeadb5cdd0301cbc02c1a266865 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass_test.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(FoldTwoSqueeze2FusePass, basic) { + Layers layers; + + auto* in_x = layers.data("in_x", {64, 1, 74, 1}); + auto* squeeze2_1_out = layers.squeeze2(in_x, std::vector{3}); + layers.squeeze2(squeeze2_1_out, std::vector{1}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("fold_two_squeeze2_fuse_pass"); + pass->Apply(graph.get()); + auto ops_num = GetNumOpNodes(graph); + PADDLE_ENFORCE_EQ( + ops_num, + 1, + platform::errors::PreconditionNotMet( + "graph should only have 2 op nodes, but received %d.", ops_num)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fold_two_squeeze2_fuse_pass); diff --git a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..92a2634938bd3697712c1733a5449e8c24635238 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h" +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +namespace patterns { +struct Reshape2MatmulV2Pattern : public PatternBase { + Reshape2MatmulV2Pattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(reshape2); + PATTERN_DECL_NODE(matmul_v2); + // declare variable node's name + PATTERN_DECL_NODE(reshape2_in); + PATTERN_DECL_NODE(matmul_x); + PATTERN_DECL_NODE(matmul_y); + PATTERN_DECL_NODE(matmul_out); +}; + +Reshape2MatmulV2Pattern::Reshape2MatmulV2Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* reshape2_in = + pattern->NewNode(reshape2_in_repr()) + ->assert_is_op_input("reshape2", "X") + ->AsInput() + ->assert_more([](Node* node) { + auto reshape2_in_x_shape = node->Var()->GetShape(); + size_t reshape2_in_rank = reshape2_in_x_shape.size(); + return (reshape2_in_rank == 4 && reshape2_in_x_shape[2] == 1 && + reshape2_in_x_shape[3] == 1); + }); + auto* reshape2 = pattern->NewNode(reshape2_repr())->assert_is_op("reshape2"); + auto matmul_x = pattern->NewNode(matmul_x_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("matmul_v2", "X") + ->assert_more([](Node* node) { + auto matmul_x_shape = node->Var()->GetShape(); + size_t matmul_x_rank = matmul_x_shape.size(); + return matmul_x_rank == 2; + }); + auto* matmul_y = pattern->NewNode(matmul_y_repr()) + ->assert_is_op_input("matmul_v2", "Y") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + auto matmul_y_shape = node->Var()->GetShape(); + size_t matmul_y_rank = matmul_y_shape.size(); + return matmul_y_rank == 2; + }); + auto* matmul_v2 = pattern->NewNode(matmul_v2_repr()) + ->assert_is_op("matmul_v2") + ->assert_op_attr("trans_x", false) + ->assert_op_attr("trans_y", true); + auto* matmul_out = pattern->NewNode(matmul_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->AsOutput(); + reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x}); + matmul_v2->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out}); +} +} // namespace patterns + +void MatmulWeightTransPass::TransMatmulV2Weight(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::Reshape2MatmulV2Pattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle TransMatmulV2Weight"; + /* declare operator node's name */ + GET_IR_NODE(reshape2); + GET_IR_NODE(matmul_v2); + /* declare variable node's name*/ + GET_IR_NODE(reshape2_in); + GET_IR_NODE(matmul_x); + GET_IR_NODE(matmul_y); + GET_IR_NODE(matmul_out); + + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + auto* matmul_y_t = + scope->GetVar(matmul_y->Name())->GetMutable(); + Transpose2D(matmul_y_t); + auto from_shape = matmul_y->Var()->GetShape(); + matmul_y->Var()->SetShape({from_shape[1], from_shape[0]}); + matmul_v2->Op()->SetAttr("trans_y", false); + matmul_v2->Op()->Flush(); + // delete useless node + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void MatmulWeightTransPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + TransMatmulV2Weight(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(matmul_weight_trans_pass, + paddle::framework::ir::MatmulWeightTransPass); + +REGISTER_PASS_CAPABILITY(matmul_weight_trans_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .EQ("matmul_v2", 0)); diff --git a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..2eeaeaaca3c89b5be156f6293d12bab11488db27 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +/* +Origin subgraph: + x + | + reshape2 + | + matmul_v2(trans_x=fasle, trans_y=true) + | +Fused subgraph: + x + reshape2 + | + matmul_v2(trans_x=fasle, trans_y=false) + | +*/ +class MatmulWeightTransPass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void TransMatmulV2Weight(ir::Graph* graph) const; + + const std::string name_scope_{"matmul_weight_trans_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass_test.cc b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dfaba5d11f89acdf54a7414d616055a22744c3d --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass_test.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(MatMulWeightTransPass, basic) { + Layers layers; + + auto* reshape2_in = layers.data("reshape2_in", {64, 256, 1, 1}); + auto* reshape2_out = layers.reshape2(reshape2_in, std::vector{-1, 256}); + auto* matmul_y = layers.data("matmul_y", {8, 256}, true); + layers.matmul_v2(reshape2_out, matmul_y, nullptr, false, true); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("matmul_weight_trans_pass"); + VLOG(3) << DebugString(graph); + pass->Apply(graph.get()); + VLOG(3) << DebugString(graph); + + bool trans_y = true; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "matmul_v2") { + trans_y = PADDLE_GET_CONST(bool, node->Op()->GetAttr("trans_y")); + } + } + PADDLE_ENFORCE_EQ( + trans_y, + false, + platform::errors::PreconditionNotMet( + "The attribute of matmul_v2 trans_y should be false after pass")); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(matmul_weight_trans_pass); diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c56113b0f4c4b358ceffe8047aa6d144a668e75 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc @@ -0,0 +1,274 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h" + +#include +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +namespace patterns { +struct MatmulV2Pattern : public PatternBase { + MatmulV2Pattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(matmul_v2); + // declare variable node's name + PATTERN_DECL_NODE(matmul_x); + PATTERN_DECL_NODE(matmul_y); + PATTERN_DECL_NODE(matmul_out); +}; + +MatmulV2Pattern::MatmulV2Pattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto matmul_x = pattern->NewNode(matmul_x_repr()) + ->assert_is_op_input("matmul_v2", "X") + ->AsInput(); + auto* matmul_y = pattern->NewNode(matmul_y_repr()) + ->assert_is_op_input("matmul_v2", "Y") + ->AsInput(); + auto* matmul_v2 = pattern->NewNode(matmul_v2_repr()) + ->assert_is_op("matmul_v2") + ->assert_more([](Node* node) { + if (node->inputs.size() != 2) { + return false; + } + return node->inputs[0]->Var()->GetShape().size() == + node->inputs[1]->Var()->GetShape().size(); + }); + auto* matmul_out = pattern->NewNode(matmul_out_repr()) + ->assert_is_op_output("matmul_v2", "Out") + ->AsOutput(); + matmul_v2->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out}); +} + +struct Reshape2MatmulPattern : public PatternBase { + Reshape2MatmulPattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(reshape2); + PATTERN_DECL_NODE(matmul); + // declare variable node's name + PATTERN_DECL_NODE(reshape2_in); + PATTERN_DECL_NODE(matmul_x); + PATTERN_DECL_NODE(matmul_y); + PATTERN_DECL_NODE(matmul_out); +}; + +Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* reshape2_in = + pattern->NewNode(reshape2_in_repr()) + ->assert_is_op_input("reshape2", "X") + ->AsInput() + ->assert_more([](Node* node) { + auto reshape2_in_x_shape = node->Var()->GetShape(); + size_t reshape2_in_rank = reshape2_in_x_shape.size(); + bool nice_shape = + (reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1) || + (reshape2_in_x_shape[1] == 1 && reshape2_in_x_shape[3] == 1); + return (reshape2_in_rank == 4 && nice_shape); + }); + auto* reshape2 = + pattern->NewNode(reshape2_repr()) + ->assert_is_op("reshape2") + ->assert_has_n_inputs(1) + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto reshape2_shape_attr = + op_desc->GetAttrIfExists>("shape"); + return reshape2_shape_attr.size() == 2; + }); + auto matmul_x = pattern->NewNode(matmul_x_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_has_n_outputs(1) + ->assert_is_op_input("matmul", "X") + ->assert_more([](Node* node) { + auto matmul_x_shape = node->Var()->GetShape(); + size_t matmul_x_rank = matmul_x_shape.size(); + return matmul_x_rank == 2; + }); + auto* matmul_y = pattern->NewNode(matmul_y_repr()) + ->assert_is_op_input("matmul", "Y") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + auto matmul_y_shape = node->Var()->GetShape(); + size_t matmul_y_rank = matmul_y_shape.size(); + return matmul_y_rank == 2; + }); + auto* matmul = pattern->NewNode(matmul_repr()) + ->assert_is_op("matmul") + ->assert_op_attr("transpose_X", false) + ->assert_op_attr("transpose_Y", false); + auto* matmul_out = pattern->NewNode(matmul_out_repr()) + ->assert_is_op_output("matmul", "Out") + ->AsOutput(); + reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x}); + matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out}); +} +} // namespace patterns + +void Reshape2MatmulXPUFusePass::FuseReshape2Matmul(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::Reshape2MatmulPattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ReShape2MatmulXPUFusePass"; + /* declare operator node's name */ + GET_IR_NODE(reshape2); + GET_IR_NODE(matmul); + /* declare variable node's name*/ + GET_IR_NODE(reshape2_in); + GET_IR_NODE(matmul_x); + GET_IR_NODE(matmul_y); + GET_IR_NODE(matmul_out); + + bool flag = true; + std::vector& next_ops = matmul_out->outputs; + flag = flag && next_ops.size() == 1 && + (next_ops[0]->Name() == "elementwise_add" || + next_ops[0]->Name() == "batch_norm"); + + if (flag) { + OpDesc desc(matmul->Op()->Block()); + desc.SetType("mul"); + desc.SetInput("X", {reshape2_in->Name()}); + desc.SetInput("Y", {matmul_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("y_num_col_dims", 1); + + auto mul_node = graph->CreateOpNode(&desc); + IR_NODE_LINK_TO(reshape2_in, mul_node); + IR_NODE_LINK_TO(matmul_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {reshape2, matmul_x, matmul}); + found_subgraph_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void Reshape2MatmulXPUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + FuseReshape2Matmul(graph); +} + +void MapMatmulV2ToMatmulXPUPass::MapMatmulV2ToMatmul(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::MatmulV2Pattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle MapMatmulV2ToMatmulXPUPass"; + /* declare operator node's name */ + GET_IR_NODE(matmul_v2); + /* declare variable node's name*/ + GET_IR_NODE(matmul_x); + GET_IR_NODE(matmul_y); + GET_IR_NODE(matmul_out); + + std::vector x_shape = matmul_x->Var()->GetShape(); + std::vector y_shape = matmul_y->Var()->GetShape(); + uint64_t dims = 2; + for (size_t i = 0; i < x_shape.size() - dims; ++i) { + if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) { + LOG(WARNING) << "matmul op not support broadcast, please check " + "inputs'shape[i]. "; + return; + } + } + OpDesc desc(matmul_v2->Op()->Block()); + desc.SetType("matmul"); + desc.SetInput("X", {matmul_x->Name()}); + desc.SetInput("Y", {matmul_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("transpose_X", matmul_v2->Op()->GetAttr("trans_x")); + desc.SetAttr("transpose_Y", matmul_v2->Op()->GetAttr("trans_y")); + desc.SetAttr("alpha", 1.0f); + if (matmul_v2->Op()->HasAttr("use_mkldnn")) { + desc.SetAttr("use_mkldnn", matmul_v2->Op()->GetAttr("use_mkldnn")); + } + auto matmul_node = graph->CreateOpNode(&desc); + IR_NODE_LINK_TO(matmul_x, matmul_node); + IR_NODE_LINK_TO(matmul_y, matmul_node); + IR_NODE_LINK_TO(matmul_node, matmul_out); + GraphSafeRemoveNodes(graph, {matmul_v2}); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void MapMatmulV2ToMatmulXPUPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + MapMatmulV2ToMatmul(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(reshape2_matmul_xpu_fuse_pass, + paddle::framework::ir::Reshape2MatmulXPUFusePass); + +REGISTER_PASS_CAPABILITY(reshape2_matmul_xpu_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .LE("matmul", 1) + .EQ("mul", 0)); + +REGISTER_PASS(map_matmulv2_to_matmul_xpu_pass, + paddle::framework::ir::MapMatmulV2ToMatmulXPUPass); + +REGISTER_PASS_CAPABILITY(map_matmulv2_to_matmul_xpu_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul_v2", 0) + .LE("matmul", 1)); diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..bc16ea9c4057c46905081cc5111b41c694d4b696 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h @@ -0,0 +1,54 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +class Reshape2MatmulXPUFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void FuseReshape2Matmul(ir::Graph* graph) const; + const std::string name_scope_{"reshape2_matmul_xpu_fuse_pass"}; +}; + +class MapMatmulV2ToMatmulXPUPass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + void MapMatmulV2ToMatmul(ir::Graph* graph) const; + const std::string name_scope_{"map_matmulv2_to_matmul_xpu_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f81e783fc051985a47b693c91660b2d19309cdb5 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(ReShape2MatmulXPUFusePass, basic) { + Layers layers; + + auto* reshape2_in = layers.data("reshape2_in", {64, 1, 74, 1}); + auto* reshape2_out = layers.reshape2(reshape2_in, std::vector{-1, 74}); + auto* matmul_y = layers.data("matmul_y", {74, 64}, true); + auto* matmul_out = + layers.matmul(reshape2_out, matmul_y, nullptr, false, false); + auto* ele_y = layers.data("ele_y", {64}, true); + layers.elementwise_add(matmul_out, ele_y); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("reshape2_matmul_xpu_fuse_pass"); + VLOG(3) << DebugString(graph); + + pass->Apply(graph.get()); + VLOG(3) << DebugString(graph); + + auto ops_num = GetNumOpNodes(graph); + PADDLE_ENFORCE_EQ( + ops_num, + 3, + platform::errors::PreconditionNotMet( + "graph should only have 2 op nodes, but received %d.", ops_num)); +} + +TEST(MapMatmulV2ToMatmulXPUPass, basic) { + Layers layers; + + auto* matmul_x = layers.data("matmul_x", {64, 74}); + auto* matmul_y = layers.data("matmul_y", {74, 64}, true); + layers.matmul_v2(matmul_x, matmul_y, nullptr, false, false); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("map_matmulv2_to_matmul_xpu_pass"); + VLOG(3) << DebugString(graph); + + pass->Apply(graph.get()); + VLOG(3) << DebugString(graph); + + auto matmuls = GetOpNodes(graph, "matmul"); + for (auto* matmul : matmuls) { + PADDLE_ENFORCE_EQ( + std::abs(matmul->Op()->GetAttrIfExists("alpha") - 1.f) < 1e-5f, + true, + platform::errors::PreconditionNotMet( + "matmul_v2 is mapped to matmul by pass.")); + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(reshape2_matmul_xpu_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index e13cf391b7a405e13499c51f217da65bcee2a3f0..b8832132044dbedf31c5bd660a15db9b7a1f8bc7 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -523,10 +523,14 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fused_multi_transformer_cachekv_layout_trans_pass", "one_beam_size_fuse_pass", "fold_interp_outsize_fuse_pass", + "fold_two_squeeze2_fuse_pass", "delete_cast_op_pass", "stack_fuse_pass", "fused_multi_transformer_xpu_pass", "sigmoid_elementmul_fuse_pass", + "matmul_weight_trans_pass", + "map_matmulv2_to_matmul_xpu_pass", + "reshape2_matmul_xpu_fuse_pass", "fc_xpu_fuse_pass", "conv2d_xpu_fuse_pass", "add_activation_xpu_fuse_pass",