Unverified commit f1c8d3fa authored by wz1qqx, committed by GitHub

add squeeze2+matmul pass (#54779)

Parent ffeac6d5
...@@ -107,6 +107,7 @@ pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(transfer_layout_elim_pass inference)
pass_library(relu6_fuse_pass inference)
pass_library(silu_fuse_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
...@@ -434,6 +435,10 @@ cc_test(
  test_delete_cast_op_pass
  SRCS delete_cast_op_pass_test.cc
  DEPS delete_cast_op_pass)
cc_test(
test_relu6_fuse_pass
SRCS relu6_fuse_pass_test.cc
DEPS relu6_fuse_pass)
if(WITH_GPU OR WITH_ROCM)
  cc_test(
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/relu6_fuse_pass.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void Relu6FusePass::ApplyImpl(ir::Graph* graph) const {
  // This pass is currently used for XPU, because XPU can fuse conv + bias + relu6.
const std::string pattern_name = "relu6_fuse";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* clip_x = gpd.mutable_pattern()
->NewNode("clip_x")
->assert_is_op_input("clip", "X")
->assert_var_not_persistable()
->AsInput();
auto clip_op =
gpd.mutable_pattern()->NewNode("clip_op")->assert_is_op("clip");
auto clip_min = gpd.mutable_pattern()
->NewNode("clip_min")
->assert_is_op_input("clip", "Min")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
})
->AsInput();
auto clip_max = gpd.mutable_pattern()
->NewNode("clip_max")
->assert_is_op_input("clip", "Max")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 1;
})
->AsInput();
auto clip_out = gpd.mutable_pattern()
->NewNode("clip_out")
->assert_is_op_output("clip")
->AsOutput();
clip_op->LinksFrom({clip_x, clip_min, clip_max}).LinksTo({clip_out});
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
Node* clip_x_node = subgraph.at(clip_x);
Node* clip_op_node = subgraph.at(clip_op);
Node* clip_max_node = subgraph.at(clip_max);
Node* clip_min_node = subgraph.at(clip_min);
Node* clip_out_node = subgraph.at(clip_out);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
const auto& clip_max_t =
scope->GetVar(clip_max_node->Name())->Get<phi::DenseTensor>();
auto clip_max_t_dims = clip_max_t.dims();
PADDLE_ENFORCE_EQ(
clip_max_t_dims.size(),
1,
platform::errors::InvalidArgument("the size(%d) of clip max tensor "
"must equal 1",
clip_max_t_dims.size()));
const auto& clip_min_t =
scope->GetVar(clip_min_node->Name())->Get<phi::DenseTensor>();
auto clip_min_t_dims = clip_min_t.dims();
PADDLE_ENFORCE_EQ(
clip_min_t_dims.size(),
1,
platform::errors::InvalidArgument("the size(%d) of clip max tensor "
"must equal 1",
clip_min_t_dims.size()));
auto tensor_type = clip_max_t.dtype();
float max_val_ = 0.f;
float min_val_ = 1.f;
if (tensor_type == phi::DataType::FLOAT16) {
auto* clip_max_t_fp16_ptr = clip_max_t.data<platform::float16>();
auto* clip_min_t_fp16_ptr = clip_min_t.data<platform::float16>();
max_val_ = static_cast<float>(clip_max_t_fp16_ptr[0]);
min_val_ = static_cast<float>(clip_min_t_fp16_ptr[0]);
} else if (tensor_type == phi::DataType::FLOAT32) {
auto* clip_max_t_fp32_ptr = clip_max_t.data<float>();
auto* clip_min_t_fp32_ptr = clip_min_t.data<float>();
max_val_ = clip_max_t_fp32_ptr[0];
min_val_ = clip_min_t_fp32_ptr[0];
} else {
PADDLE_THROW(platform::errors::Unavailable(
"relu6_fuse_pass do not supported weight dtype. "
"we now only support fp32/fp16."));
}
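    // clip(x, min, max) is equivalent to relu6 only when min == 0 and
    // max == 6, so require both values (within a small tolerance) before
    // rewriting the subgraph.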
if (std::abs(max_val_ - 6.0) < 1e-3 && std::abs(min_val_ - 0.0) < 1e-3) {
OpDesc new_desc;
new_desc.SetType("relu6");
new_desc.SetAttr("threshold", 6.f);
new_desc.SetInput("X", {clip_x_node->Name()});
new_desc.SetOutput("Out", {clip_out_node->Name()});
new_desc.Flush();
std::unordered_set<const Node*> del_node_set;
del_node_set.insert(clip_op_node);
del_node_set.insert(clip_max_node);
del_node_set.insert(clip_min_node);
GraphSafeRemoveNodes(graph, del_node_set);
auto fused_node = graph->CreateOpNode(&new_desc);
IR_NODE_LINK_TO(clip_x_node, fused_node);
IR_NODE_LINK_TO(fused_node, clip_out_node);
}
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(relu6_fuse_pass, paddle::framework::ir::Relu6FusePass);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
/*
Fuse the fill_constant + clip block into a relu6 op.
For example:
graph:
Min(0) Input Max(6.0)
\ | /
\ | /
clip
|
|
Output
------------------------------------------------------
After the pass is applied:
Input
|
|
relu6
|
|
Output
*/
class Relu6FusePass : public FusePassBase {
public:
virtual ~Relu6FusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
const std::string name_scope_{"relu6_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
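The diagram above relies on the identity relu6(x) = min(max(x, 0), 6); a clip op with Min = 0 and Max = 6 computes exactly relu6. A minimal standalone C++ sketch (not part of this commit) that checks the identity on a few sample values:

#include <algorithm>
#include <cassert>

// clip as matched by the pattern above: clamp x into [min_v, max_v].
static float clip(float x, float min_v, float max_v) {
  return std::min(std::max(x, min_v), max_v);
}

// relu6 as produced by the pass: clamp x into [0, 6].
static float relu6(float x) { return std::min(std::max(x, 0.f), 6.f); }

int main() {
  const float samples[] = {-3.f, 0.f, 2.5f, 6.f, 9.f};
  for (float x : samples) {
    assert(clip(x, 0.f, 6.f) == relu6(x));
  }
  return 0;
}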
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
template <typename T = float>
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims,
T value = 0) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
auto* data = cpu_ctx->Alloc<T>(tensor);
for (int64_t i = 0; i < tensor->numel(); i++) {
data[i] = value;
}
}
TEST(Relu6FusePass, basic) {
Layers layers;
auto* in_x = layers.data("in_x", {1, 32, 112, 112});
auto* clip_min = layers.data("clip_x", {1}, true);
auto* clip_max = layers.data("clip_y", {1}, true);
layers.clip(in_x, clip_min, clip_max);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto* param_scope = new Scope();
graph->Set("__param_scope__", param_scope);
AddVarToScope(param_scope, clip_min->Name(), {1}, 0.f);
AddVarToScope(param_scope, clip_max->Name(), {1}, 6.f);
auto pass = PassRegistry::Instance().Get("relu6_fuse_pass");
VLOG(3) << DebugString(graph);
pass->Apply(graph.get());
VLOG(3) << DebugString(graph);
auto clip_num = GetNumOpNodes(graph, "clip");
PADDLE_ENFORCE_EQ(clip_num,
0,
platform::errors::PreconditionNotMet(
"clip should be mapped to relu6 after pass."));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(relu6_fuse_pass);
...@@ -139,6 +139,76 @@ Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern,
  reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x});
  matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
struct Squeeze2MatmulPattern : public PatternBase {
Squeeze2MatmulPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(squeeze2);
PATTERN_DECL_NODE(matmul);
// declare variable node's name
PATTERN_DECL_NODE(squeeze2_in);
PATTERN_DECL_NODE(matmul_x);
PATTERN_DECL_NODE(matmul_y);
PATTERN_DECL_NODE(matmul_out);
};
Squeeze2MatmulPattern::Squeeze2MatmulPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* squeeze2_in =
pattern->NewNode(squeeze2_in_repr())
->assert_is_op_input("squeeze2", "X")
->AsInput()
          ->assert_more([](Node* node) {
            auto squeeze2_in_x_shape = node->Var()->GetShape();
            size_t squeeze2_in_rank = squeeze2_in_x_shape.size();
            // Check the rank first so the [2]/[3] accesses are in bounds.
            return squeeze2_in_rank == 4 && squeeze2_in_x_shape[2] == 1 &&
                   squeeze2_in_x_shape[3] == 1;
          });
auto* squeeze2 = pattern->NewNode(squeeze2_repr())
->assert_is_op("squeeze2")
->assert_has_n_inputs(1)
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto squeeze2_op_axes =
op_desc->GetAttrIfExists<std::vector<int>>("axes");
return squeeze2_op_axes == std::vector<int>{2, 3};
});
auto matmul_x = pattern->NewNode(matmul_x_repr())
->assert_is_op_output("squeeze2", "Out")
->assert_has_n_outputs(1)
->assert_is_op_input("matmul", "X")
->assert_more([](Node* node) {
auto matmul_x_shape = node->Var()->GetShape();
size_t matmul_x_rank = matmul_x_shape.size();
return matmul_x_rank == 2;
});
auto* matmul_y = pattern->NewNode(matmul_y_repr())
->assert_is_op_input("matmul", "Y")
->assert_is_persistable_var()
->assert_more([](Node* node) {
auto matmul_y_shape = node->Var()->GetShape();
size_t matmul_y_rank = matmul_y_shape.size();
return matmul_y_rank == 2;
});
auto* matmul = pattern->NewNode(matmul_repr())
->assert_is_op("matmul")
->assert_op_attr<bool>("transpose_X", false)
->assert_op_attr<bool>("transpose_Y", false)
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto matmul_alpha_attr =
op_desc->GetAttrIfExists<float>("alpha");
return std::abs(matmul_alpha_attr - 1.f) < 1e-5;
});
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul", "Out")
->AsOutput();
squeeze2->LinksFrom({squeeze2_in}).LinksTo({matmul_x});
matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
}
}  // namespace patterns
void Reshape2MatmulXPUFusePass::FuseReshape2Matmul(ir::Graph* graph) const {
...@@ -250,6 +320,59 @@ void MapMatmulV2ToMatmulXPUPass::ApplyImpl(ir::Graph* graph) const {
  MapMatmulV2ToMatmul(graph);
}
void Squeeze2MatmulXPUFusePass::FuseSqueeze2Matmul(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::Squeeze2MatmulPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle Squeeze2MatmulXPUFusePass";
/* declare operator node's name */
GET_IR_NODE(squeeze2);
GET_IR_NODE(matmul);
/* declare variable node's name*/
GET_IR_NODE(squeeze2_in);
GET_IR_NODE(matmul_x);
GET_IR_NODE(matmul_y);
GET_IR_NODE(matmul_out);
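    // Only fuse when the matmul output feeds exactly one op and that op is
    // elementwise_add or batch_norm; other consumers are left untouched.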
bool flag = true;
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
(next_ops[0]->Name() == "elementwise_add" ||
next_ops[0]->Name() == "batch_norm");
if (flag) {
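      // squeeze2 here only drops the trailing 1x1 dims, so the original 4-D
      // input can be wired into mul directly: with x_num_col_dims = 1, mul
      // flattens [N, C, 1, 1] into the same [N, C] matrix that matmul
      // received from squeeze2.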
OpDesc desc(matmul->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {squeeze2_in->Name()});
desc.SetInput("Y", {matmul_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);
desc.SetAttr("y_num_col_dims", 1);
auto mul_node = graph->CreateOpNode(&desc);
IR_NODE_LINK_TO(squeeze2_in, mul_node);
IR_NODE_LINK_TO(matmul_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {squeeze2, matmul_x, matmul});
found_subgraph_count++;
}
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void Squeeze2MatmulXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FuseSqueeze2Matmul(graph);
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
...@@ -272,3 +395,13 @@ REGISTER_PASS_CAPABILITY(map_matmulv2_to_matmul_xpu_pass)
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("matmul_v2", 0)
            .LE("matmul", 1));
REGISTER_PASS(squeeze2_matmul_xpu_fuse_pass,
paddle::framework::ir::Squeeze2MatmulXPUFusePass);
REGISTER_PASS_CAPABILITY(squeeze2_matmul_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("squeeze2", 0)
.LE("matmul", 1)
.EQ("mul", 0));
...@@ -31,6 +31,15 @@ namespace paddle {
namespace framework {
namespace ir {
class Squeeze2MatmulXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FuseSqueeze2Matmul(ir::Graph* graph) const;
const std::string name_scope_{"squeeze2_matmul_xpu_fuse_pass"};
};
class Reshape2MatmulXPUFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;
......
...@@ -22,6 +22,32 @@ namespace paddle {
namespace framework {
namespace ir {
TEST(Squeeze2MatmulXPUFusePass, basic) {
Layers layers;
  auto* squeeze2_in = layers.data("squeeze2_in", {64, 74, 1, 1});
  auto* squeeze2_out = layers.squeeze2(squeeze2_in, std::vector<int>{2, 3});
auto* matmul_y = layers.data("matmul_y", {74, 64}, true);
auto* matmul_out =
layers.matmul(squeeze2_out, matmul_y, nullptr, false, false);
auto* ele_y = layers.data("ele_y", {64}, true);
layers.elementwise_add(matmul_out, ele_y);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("squeeze2_matmul_xpu_fuse_pass");
VLOG(3) << DebugString(graph);
pass->Apply(graph.get());
VLOG(3) << DebugString(graph);
auto ops_num = GetNumOpNodes(graph);
PADDLE_ENFORCE_EQ(
ops_num,
      2,
platform::errors::PreconditionNotMet(
"graph should only have 2 op nodes, but received %d.", ops_num));
}
TEST(ReShape2MatmulXPUFusePass, basic) {
  Layers layers;
......
...@@ -529,10 +529,12 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"xpu_delete_cast_op_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_pass",
"relu6_fuse_pass",
"sigmoid_elementmul_fuse_pass", "sigmoid_elementmul_fuse_pass",
"matmul_weight_trans_pass", "matmul_weight_trans_pass",
"map_matmulv2_to_matmul_xpu_pass", "map_matmulv2_to_matmul_xpu_pass",
"reshape2_matmul_xpu_fuse_pass", "reshape2_matmul_xpu_fuse_pass",
"squeeze2_matmul_xpu_fuse_pass",
"redundant_squeeze_unsqueeze_elimination_pass", "redundant_squeeze_unsqueeze_elimination_pass",
"fc_xpu_fuse_pass", "fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass", "conv2d_xpu_fuse_pass",
......
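Both new passes are reached through XpuPassStrategy, so they run when XPU inference is enabled. Below is a minimal sketch of the C++ inference API, not part of this commit; the include path, model file names, and EnableXpu() defaults are assumptions and may differ across Paddle versions.

#include "paddle_inference_api.h"  // assumed install-relative include path

int main() {
  paddle_infer::Config config;
  // Hypothetical model files; replace with a real exported model.
  config.SetModel("model.pdmodel", "model.pdiparams");
  // Selecting the XPU target makes the analysis phase use XpuPassStrategy,
  // which now contains relu6_fuse_pass and squeeze2_matmul_xpu_fuse_pass.
  config.EnableXpu();
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}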