Unverified · Commit cc4f5d05 authored by W wz1qqx, committed by GitHub

[XPU]Add shape+matmul relative pass (#54574)

Parent 36a5ff50
......@@ -260,6 +260,11 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fold_two_squeeze2_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -547,4 +552,16 @@ if(WITH_XPU)
test_fold_interp_outsize_fuse_pass
SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
DEPS fold_interp_outsize_fuse_pass)
cc_test(
test_fold_two_squeeze2_fuse_pass
SRCS xpu/fold_two_squeeze2_fuse_pass_test.cc
DEPS fold_two_squeeze2_fuse_pass)
cc_test(
test_matmul_weight_trans_pass
SRCS xpu/matmul_weight_trans_pass_test.cc
DEPS matmul_weight_trans_pass)
cc_test(
test_reshape2_matmul_xpu_fuse_pass
SRCS xpu/reshape2_matmul_xpu_fuse_pass_test.cc
DEPS reshape2_matmul_xpu_fuse_pass)
endif()
File mode changed from 100755 to 100644
......@@ -134,6 +134,22 @@ struct Layers {
return out;
}
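// Appends a squeeze2 op that drops the given axes of x; when with_xshape is
// true it also emits the auxiliary XShape output that squeeze2 can produce.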
VarDesc* squeeze2(VarDesc* x,
const std::vector<int> axes = {-1},
bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("squeeze2");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("axes", axes);
if (with_xshape) {
VarDesc* xshape = lod_tensor(unique_name());
op->SetOutput("XShape", {xshape->Name()});
}
return out;
}
VarDesc* unsqueeze2(VarDesc* x, const std::vector<int> axes = {-1}) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
......@@ -420,6 +436,17 @@ struct Layers {
return out;
}
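// Appends a clip op that clamps x element-wise to the range [min, max].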
VarDesc* clip(VarDesc* x, VarDesc* min, VarDesc* max) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("clip");
op->SetInput("X", {x->Name()});
op->SetInput("Min", {min->Name()});
op->SetInput("Max", {max->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* matmul_v2(VarDesc* x,
VarDesc* y,
VarDesc* alpha = nullptr,
......
......@@ -22,23 +22,14 @@
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct DetectorFusePattern : public PatternBase {
DetectorFusePattern(PDPattern* pattern, const std::string& name_scope);
struct InterpOutsizeFusePattern : public PatternBase {
InterpOutsizeFusePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(shape);
......@@ -60,8 +51,8 @@ struct DetectorFusePattern : public PatternBase {
PATTERN_DECL_NODE(cast2_out);
};
DetectorFusePattern::DetectorFusePattern(PDPattern* pattern,
const std::string& name_scope)
InterpOutsizeFusePattern::InterpOutsizeFusePattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input("shape", "Input")
......@@ -144,9 +135,10 @@ DetectorFusePattern::DetectorFusePattern(PDPattern* pattern,
} // namespace patterns
void FoldInterpOutsizeFusePass::DetectorFuse(ir::Graph* graph) const {
void FoldInterpOutsizeFusePass::FoldInterpOutsize(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::DetectorFusePattern pattern(gpd.mutable_pattern(), name_scope_);
patterns::InterpOutsizeFusePattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......@@ -213,7 +205,7 @@ void FoldInterpOutsizeFusePass::ApplyImpl(ir::Graph* graph) const {
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
DetectorFuse(graph);
FoldInterpOutsize(graph);
}
} // namespace ir
......
......@@ -64,7 +64,7 @@ class FoldInterpOutsizeFusePass : public FusePassBase {
| /
bilinear_interp_v2
*/
void DetectorFuse(ir::Graph* graph) const;
void FoldInterpOutsize(ir::Graph* graph) const;
const std::string name_scope_{"fold_interp_outsize_fuse_pass"};
};
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
namespace ir {
TEST(DetectorFuse, basic) {
TEST(FoldInterpOutsizeFusePass, basic) {
Layers layers;
auto* block = layers.Block();
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/fold_two_squeeze2_fuse_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TwoSqueeze2FusePattern : public PatternBase {
TwoSqueeze2FusePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(squeeze2_1);
PATTERN_DECL_NODE(squeeze2_2);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(squeeze2_1_out);
PATTERN_DECL_NODE(squeeze2_2_out);
};
TwoSqueeze2FusePattern::TwoSqueeze2FusePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input("squeeze2", "X")
->AsInput()
->assert_more([](Node* node) {
auto squeeze2_in_x_shape = node->Var()->GetShape();
size_t squeeze2_in_rank = squeeze2_in_x_shape.size();
// Check the rank first: indexing into the shape is only safe
// for rank-4 inputs.
if (squeeze2_in_rank != 4) return false;
return squeeze2_in_x_shape[1] == 1 &&
squeeze2_in_x_shape[2] == 74 &&
squeeze2_in_x_shape[3] == 1;
});
auto* squeeze2_1 = pattern->NewNode(squeeze2_1_repr())
->assert_is_op("squeeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>(
"axes") == std::vector<int>{3};
});
auto* squeeze2_1_out = pattern->NewNode(squeeze2_1_out_repr())
->assert_is_op_output("squeeze2", "Out")
->assert_has_n_outputs(1)
->assert_is_op_input("squeeze2", "X");
squeeze2_1->LinksFrom({x}).LinksTo({squeeze2_1_out});
auto* squeeze2_2 = pattern->NewNode(squeeze2_2_repr())
->assert_is_op("squeeze2")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>(
"axes") == std::vector<int>{1};
});
auto* squeeze2_2_out = pattern->NewNode(squeeze2_2_out_repr())
->assert_is_op_output("squeeze2", "Out")
->assert_has_n_outputs(1);
squeeze2_2->LinksFrom({squeeze2_1_out}).LinksTo({squeeze2_2_out});
}
} // namespace patterns
void FoldTwoSqueeze2FusePass::FoldTwoSqueeze2(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::TwoSqueeze2FusePattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FoldTwoSqueeze2FusePass";
// declare operator node's name
GET_IR_NODE(squeeze2_1);
GET_IR_NODE(squeeze2_2);
// declare variable node's name
GET_IR_NODE(x);
GET_IR_NODE(squeeze2_1_out);
GET_IR_NODE(squeeze2_2_out);
auto* block = squeeze2_1->Op()->Block();
// Generate reshape2 op
framework::OpDesc reshape2_op_desc(block);
reshape2_op_desc.SetType("reshape2");
reshape2_op_desc.SetInput("X", {x->Name()});
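// For a matched input of shape [N, 1, 74, 1], squeeze2(axes={3}) yields
// [N, 1, 74] and squeeze2(axes={1}) then yields [N, 74], so a single
// reshape2 with shape {-1, 74} reproduces both squeezes.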
reshape2_op_desc.SetAttr("shape", std::vector<int>{-1, 74});
reshape2_op_desc.SetOutput("Out", {squeeze2_2_out->Name()});
auto* reshape2 = graph->CreateOpNode(&reshape2_op_desc);
IR_NODE_LINK_TO(x, reshape2);
IR_NODE_LINK_TO(reshape2, squeeze2_2_out);
// delete useless nodes
std::unordered_set<const Node*> delete_nodes = {
squeeze2_1, squeeze2_2, squeeze2_1_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void FoldTwoSqueeze2FusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FoldTwoSqueeze2(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fold_two_squeeze2_fuse_pass,
paddle::framework::ir::FoldTwoSqueeze2FusePass);
REGISTER_PASS_CAPABILITY(fold_two_squeeze2_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"squeeze2", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
/*
Origin subgraph:
x
|
squeeze2
|
squeeze2
|
Fused subgraph:
x
|
reshape2
|
*/
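// Note: the pattern only matches a rank-4 input of shape [N, 1, 74, 1] with
// squeeze2 axes {3} followed by axes {1}; the pair is replaced by a single
// reshape2 with shape {-1, 74}.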
class FoldTwoSqueeze2FusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FoldTwoSqueeze2(ir::Graph* graph) const;
const std::string name_scope_{"fold_two_squeeze2_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(FoldTwoSqueeze2FusePass, basic) {
Layers layers;
auto* in_x = layers.data("in_x", {64, 1, 74, 1});
auto* squeeze2_1_out = layers.squeeze2(in_x, std::vector<int>{3});
layers.squeeze2(squeeze2_1_out, std::vector<int>{1});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fold_two_squeeze2_fuse_pass");
pass->Apply(graph.get());
auto ops_num = GetNumOpNodes(graph);
PADDLE_ENFORCE_EQ(
ops_num,
1,
platform::errors::PreconditionNotMet(
"graph should only have 2 op nodes, but received %d.", ops_num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fold_two_squeeze2_fuse_pass);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/matmul_weight_trans_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct Reshape2MatmulV2Pattern : public PatternBase {
Reshape2MatmulV2Pattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(reshape2);
PATTERN_DECL_NODE(matmul_v2);
// declare variable node's name
PATTERN_DECL_NODE(reshape2_in);
PATTERN_DECL_NODE(matmul_x);
PATTERN_DECL_NODE(matmul_y);
PATTERN_DECL_NODE(matmul_out);
};
Reshape2MatmulV2Pattern::Reshape2MatmulV2Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* reshape2_in =
pattern->NewNode(reshape2_in_repr())
->assert_is_op_input("reshape2", "X")
->AsInput()
->assert_more([](Node* node) {
auto reshape2_in_x_shape = node->Var()->GetShape();
size_t reshape2_in_rank = reshape2_in_x_shape.size();
return (reshape2_in_rank == 4 && reshape2_in_x_shape[2] == 1 &&
reshape2_in_x_shape[3] == 1);
});
auto* reshape2 = pattern->NewNode(reshape2_repr())->assert_is_op("reshape2");
auto matmul_x = pattern->NewNode(matmul_x_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("matmul_v2", "X")
->assert_more([](Node* node) {
auto matmul_x_shape = node->Var()->GetShape();
size_t matmul_x_rank = matmul_x_shape.size();
return matmul_x_rank == 2;
});
auto* matmul_y = pattern->NewNode(matmul_y_repr())
->assert_is_op_input("matmul_v2", "Y")
->assert_is_persistable_var()
->assert_more([](Node* node) {
auto matmul_y_shape = node->Var()->GetShape();
size_t matmul_y_rank = matmul_y_shape.size();
return matmul_y_rank == 2;
});
auto* matmul_v2 = pattern->NewNode(matmul_v2_repr())
->assert_is_op("matmul_v2")
->assert_op_attr<bool>("trans_x", false)
->assert_op_attr<bool>("trans_y", true);
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul_v2", "Out")
->AsOutput();
reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x});
matmul_v2->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
} // namespace patterns
void MatmulWeightTransPass::TransMatmulV2Weight(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::Reshape2MatmulV2Pattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle TransMatmulV2Weight";
/* declare operator node's name */
GET_IR_NODE(reshape2);
GET_IR_NODE(matmul_v2);
/* declare variable node's name */
GET_IR_NODE(reshape2_in);
GET_IR_NODE(matmul_x);
GET_IR_NODE(matmul_y);
GET_IR_NODE(matmul_out);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
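// x * Y^T with trans_y=true equals x * Y' with trans_y=false once the
// persistable weight Y has been transposed in place, so transpose the
// tensor and its declared shape at compile time and clear the attribute.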
auto* matmul_y_t =
scope->GetVar(matmul_y->Name())->GetMutable<phi::DenseTensor>();
Transpose2D(matmul_y_t);
auto from_shape = matmul_y->Var()->GetShape();
matmul_y->Var()->SetShape({from_shape[1], from_shape[0]});
matmul_v2->Op()->SetAttr("trans_y", false);
matmul_v2->Op()->Flush();
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void MatmulWeightTransPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
TransMatmulV2Weight(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_weight_trans_pass,
paddle::framework::ir::MatmulWeightTransPass);
REGISTER_PASS_CAPABILITY(matmul_weight_trans_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("matmul_v2", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
/*
Origin subgraph:
x
|
reshape2
|
matmul_v2(trans_x=false, trans_y=true)
|
Fused subgraph:
x
|
reshape2
|
matmul_v2(trans_x=false, trans_y=false)
|
*/
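// Because the weight Y of matmul_v2 is persistable, it can be transposed
// once at compile time, turning x * Y^T (trans_y=true) into a plain
// x * Y (trans_y=false) with no runtime transpose.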
class MatmulWeightTransPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void TransMatmulV2Weight(ir::Graph* graph) const;
const std::string name_scope_{"matmul_weight_trans_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(MatmulWeightTransPass, basic) {
Layers layers;
auto* reshape2_in = layers.data("reshape2_in", {64, 256, 1, 1});
auto* reshape2_out = layers.reshape2(reshape2_in, std::vector<int>{-1, 256});
auto* matmul_y = layers.data("matmul_y", {8, 256}, true);
layers.matmul_v2(reshape2_out, matmul_y, nullptr, false, true);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("matmul_weight_trans_pass");
VLOG(3) << DebugString(graph);
pass->Apply(graph.get());
VLOG(3) << DebugString(graph);
bool trans_y = true;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "matmul_v2") {
trans_y = PADDLE_GET_CONST(bool, node->Op()->GetAttr("trans_y"));
}
}
PADDLE_ENFORCE_EQ(
trans_y,
false,
platform::errors::PreconditionNotMet(
"The attribute of matmul_v2 trans_y should be false after pass"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(matmul_weight_trans_pass);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h"
#include <cmath>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct MatmulV2Pattern : public PatternBase {
MatmulV2Pattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(matmul_v2);
// declare variable node's name
PATTERN_DECL_NODE(matmul_x);
PATTERN_DECL_NODE(matmul_y);
PATTERN_DECL_NODE(matmul_out);
};
MatmulV2Pattern::MatmulV2Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto matmul_x = pattern->NewNode(matmul_x_repr())
->assert_is_op_input("matmul_v2", "X")
->AsInput();
auto* matmul_y = pattern->NewNode(matmul_y_repr())
->assert_is_op_input("matmul_v2", "Y")
->AsInput();
auto* matmul_v2 = pattern->NewNode(matmul_v2_repr())
->assert_is_op("matmul_v2")
->assert_more([](Node* node) {
if (node->inputs.size() != 2) {
return false;
}
return node->inputs[0]->Var()->GetShape().size() ==
node->inputs[1]->Var()->GetShape().size();
});
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul_v2", "Out")
->AsOutput();
matmul_v2->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
struct Reshape2MatmulPattern : public PatternBase {
Reshape2MatmulPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(reshape2);
PATTERN_DECL_NODE(matmul);
// declare variable node's name
PATTERN_DECL_NODE(reshape2_in);
PATTERN_DECL_NODE(matmul_x);
PATTERN_DECL_NODE(matmul_y);
PATTERN_DECL_NODE(matmul_out);
};
Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* reshape2_in =
pattern->NewNode(reshape2_in_repr())
->assert_is_op_input("reshape2", "X")
->AsInput()
->assert_more([](Node* node) {
auto reshape2_in_x_shape = node->Var()->GetShape();
size_t reshape2_in_rank = reshape2_in_x_shape.size();
// Check the rank first: indexing into the shape is only safe
// for rank-4 inputs.
if (reshape2_in_rank != 4) return false;
return (reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1) ||
(reshape2_in_x_shape[1] == 1 && reshape2_in_x_shape[3] == 1);
});
auto* reshape2 =
pattern->NewNode(reshape2_repr())
->assert_is_op("reshape2")
->assert_has_n_inputs(1)
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto reshape2_shape_attr =
op_desc->GetAttrIfExists<std::vector<int>>("shape");
return reshape2_shape_attr.size() == 2;
});
auto matmul_x = pattern->NewNode(matmul_x_repr())
->assert_is_op_output("reshape2", "Out")
->assert_has_n_outputs(1)
->assert_is_op_input("matmul", "X")
->assert_more([](Node* node) {
auto matmul_x_shape = node->Var()->GetShape();
size_t matmul_x_rank = matmul_x_shape.size();
return matmul_x_rank == 2;
});
auto* matmul_y = pattern->NewNode(matmul_y_repr())
->assert_is_op_input("matmul", "Y")
->assert_is_persistable_var()
->assert_more([](Node* node) {
auto matmul_y_shape = node->Var()->GetShape();
size_t matmul_y_rank = matmul_y_shape.size();
return matmul_y_rank == 2;
});
auto* matmul = pattern->NewNode(matmul_repr())
->assert_is_op("matmul")
->assert_op_attr<bool>("transpose_X", false)
->assert_op_attr<bool>("transpose_Y", false);
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul", "Out")
->AsOutput();
reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x});
matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
} // namespace patterns
void Reshape2MatmulXPUFusePass::FuseReshape2Matmul(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::Reshape2MatmulPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ReShape2MatmulXPUFusePass";
/* declare operator node's name */
GET_IR_NODE(reshape2);
GET_IR_NODE(matmul);
/* declare variable node's name */
GET_IR_NODE(reshape2_in);
GET_IR_NODE(matmul_x);
GET_IR_NODE(matmul_y);
GET_IR_NODE(matmul_out);
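// Only fuse when the matmul output feeds exactly one op and that op is an
// elementwise_add or batch_norm.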
bool flag = true;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
(next_ops[0]->Name() == "elementwise_add" ||
next_ops[0]->Name() == "batch_norm");
if (flag) {
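// mul with x_num_col_dims = 1 flattens the trailing dims of X by itself,
// turning the [N, C, 1, 1] input into [N, C]; the reshape2 is therefore
// redundant and mul can read reshape2_in directly.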
OpDesc desc(matmul->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {reshape2_in->Name()});
desc.SetInput("Y", {matmul_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
auto mul_node = graph->CreateOpNode(&desc);
IR_NODE_LINK_TO(reshape2_in, mul_node);
IR_NODE_LINK_TO(matmul_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {reshape2, matmul_x, matmul});
found_subgraph_count++;
}
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void Reshape2MatmulXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FuseReshape2Matmul(graph);
}
void MapMatmulV2ToMatmulXPUPass::MapMatmulV2ToMatmul(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::MatmulV2Pattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle MapMatmulV2ToMatmulXPUPass";
/* declare operator node's name */
GET_IR_NODE(matmul_v2);
/* declare variable node's name */
GET_IR_NODE(matmul_x);
GET_IR_NODE(matmul_y);
GET_IR_NODE(matmul_out);
std::vector<int64_t> x_shape = matmul_x->Var()->GetShape();
std::vector<int64_t> y_shape = matmul_y->Var()->GetShape();
uint64_t dims = 2;
for (size_t i = 0; i + dims < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) {
LOG(WARNING) << "matmul op does not support broadcast; please check "
"the inputs' shapes.";
return;
}
}
OpDesc desc(matmul_v2->Op()->Block());
desc.SetType("matmul");
desc.SetInput("X", {matmul_x->Name()});
desc.SetInput("Y", {matmul_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("transpose_X", matmul_v2->Op()->GetAttr("trans_x"));
desc.SetAttr("transpose_Y", matmul_v2->Op()->GetAttr("trans_y"));
desc.SetAttr("alpha", 1.0f);
if (matmul_v2->Op()->HasAttr("use_mkldnn")) {
desc.SetAttr("use_mkldnn", matmul_v2->Op()->GetAttr("use_mkldnn"));
}
auto matmul_node = graph->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_x, matmul_node);
IR_NODE_LINK_TO(matmul_y, matmul_node);
IR_NODE_LINK_TO(matmul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_v2});
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void MapMatmulV2ToMatmulXPUPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
MapMatmulV2ToMatmul(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reshape2_matmul_xpu_fuse_pass,
paddle::framework::ir::Reshape2MatmulXPUFusePass);
REGISTER_PASS_CAPABILITY(reshape2_matmul_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.LE("matmul", 1)
.EQ("mul", 0));
REGISTER_PASS(map_matmulv2_to_matmul_xpu_pass,
paddle::framework::ir::MapMatmulV2ToMatmulXPUPass);
REGISTER_PASS_CAPABILITY(map_matmulv2_to_matmul_xpu_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class Reshape2MatmulXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FuseReshape2Matmul(ir::Graph* graph) const;
const std::string name_scope_{"reshape2_matmul_xpu_fuse_pass"};
};
class MapMatmulV2ToMatmulXPUPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void MapMatmulV2ToMatmul(ir::Graph* graph) const;
const std::string name_scope_{"map_matmulv2_to_matmul_xpu_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(Reshape2MatmulXPUFusePass, basic) {
Layers layers;
auto* reshape2_in = layers.data("reshape2_in", {64, 1, 74, 1});
auto* reshape2_out = layers.reshape2(reshape2_in, std::vector<int>{-1, 74});
auto* matmul_y = layers.data("matmul_y", {74, 64}, true);
auto* matmul_out =
layers.matmul(reshape2_out, matmul_y, nullptr, false, false);
auto* ele_y = layers.data("ele_y", {64}, true);
layers.elementwise_add(matmul_out, ele_y);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("reshape2_matmul_xpu_fuse_pass");
VLOG(3) << DebugString(graph);
pass->Apply(graph.get());
VLOG(3) << DebugString(graph);
auto ops_num = GetNumOpNodes(graph);
PADDLE_ENFORCE_EQ(
ops_num,
3,
platform::errors::PreconditionNotMet(
"graph should only have 2 op nodes, but received %d.", ops_num));
}
TEST(MapMatmulV2ToMatmulXPUPass, basic) {
Layers layers;
auto* matmul_x = layers.data("matmul_x", {64, 74});
auto* matmul_y = layers.data("matmul_y", {74, 64}, true);
layers.matmul_v2(matmul_x, matmul_y, nullptr, false, false);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("map_matmulv2_to_matmul_xpu_pass");
VLOG(3) << DebugString(graph);
pass->Apply(graph.get());
VLOG(3) << DebugString(graph);
auto matmuls = GetOpNodes(graph, "matmul");
for (auto* matmul : matmuls) {
PADDLE_ENFORCE_EQ(
std::abs(matmul->Op()->GetAttrIfExists<float>("alpha") - 1.f) < 1e-5f,
true,
platform::errors::PreconditionNotMet(
"matmul_v2 is mapped to matmul by pass."));
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(reshape2_matmul_xpu_fuse_pass);
......@@ -523,10 +523,14 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_cachekv_layout_trans_pass",
"one_beam_size_fuse_pass",
"fold_interp_outsize_fuse_pass",
"fold_two_squeeze2_fuse_pass",
"delete_cast_op_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_pass",
"sigmoid_elementmul_fuse_pass",
"matmul_weight_trans_pass",
"map_matmulv2_to_matmul_xpu_pass",
"reshape2_matmul_xpu_fuse_pass",
"fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass",
"add_activation_xpu_fuse_pass",
......