diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index e6a4c5e8e73ff5bbc44fa9c832e11e8ed7ceb159..021a494ef3c5452e95c18f8da76e7db1a0eac5e0 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -107,6 +107,7 @@ pass_library(constant_folding_pass inference)
 pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
 pass_library(transfer_layout_elim_pass inference)
+pass_library(relu6_fuse_pass inference)
 pass_library(silu_fuse_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
@@ -434,6 +435,10 @@ cc_test(
   test_delete_cast_op_pass
   SRCS delete_cast_op_pass_test.cc
   DEPS delete_cast_op_pass)
+cc_test(
+  test_relu6_fuse_pass
+  SRCS relu6_fuse_pass_test.cc
+  DEPS relu6_fuse_pass)
 if(WITH_GPU OR WITH_ROCM)
   cc_test(
diff --git a/paddle/fluid/framework/ir/relu6_fuse_pass.cc b/paddle/fluid/framework/ir/relu6_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a4ee3716e5a42aba151558b4d61b13c18283a1d3
--- /dev/null
+++ b/paddle/fluid/framework/ir/relu6_fuse_pass.cc
@@ -0,0 +1,137 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
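+
+// Matches a clip op whose Min/Max inputs are 1-D persistable tensors holding
+// exactly 0 and 6, and rewrites it in place to relu6.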
+ +#include "paddle/fluid/framework/ir/relu6_fuse_pass.h" + +#include +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void Relu6FusePass::ApplyImpl(ir::Graph* graph) const { + // This pass is now used for xpu, because xpu can fuse conv + bias + relu6 + const std::string pattern_name = "relu6_fuse"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + + auto* clip_x = gpd.mutable_pattern() + ->NewNode("clip_x") + ->assert_is_op_input("clip", "X") + ->assert_var_not_persistable() + ->AsInput(); + auto clip_op = + gpd.mutable_pattern()->NewNode("clip_op")->assert_is_op("clip"); + auto clip_min = gpd.mutable_pattern() + ->NewNode("clip_min") + ->assert_is_op_input("clip", "Min") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }) + ->AsInput(); + auto clip_max = gpd.mutable_pattern() + ->NewNode("clip_max") + ->assert_is_op_input("clip", "Max") + ->assert_is_persistable_var() + ->assert_more([](Node* node) { + return node->Var()->GetShape().size() == 1; + }) + ->AsInput(); + auto clip_out = gpd.mutable_pattern() + ->NewNode("clip_out") + ->assert_is_op_output("clip") + ->AsOutput(); + + clip_op->LinksFrom({clip_x, clip_min, clip_max}).LinksTo({clip_out}); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + Node* clip_x_node = subgraph.at(clip_x); + Node* clip_op_node = subgraph.at(clip_op); + Node* clip_max_node = subgraph.at(clip_max); + Node* clip_min_node = subgraph.at(clip_min); + Node* clip_out_node = subgraph.at(clip_out); + + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + const auto& clip_max_t = + scope->GetVar(clip_max_node->Name())->Get(); + auto clip_max_t_dims = clip_max_t.dims(); + PADDLE_ENFORCE_EQ( + clip_max_t_dims.size(), + 1, + platform::errors::InvalidArgument("the size(%d) of clip max tensor " + "must equal 1", + clip_max_t_dims.size())); + const auto& clip_min_t = + scope->GetVar(clip_min_node->Name())->Get(); + auto clip_min_t_dims = clip_min_t.dims(); + PADDLE_ENFORCE_EQ( + clip_min_t_dims.size(), + 1, + platform::errors::InvalidArgument("the size(%d) of clip max tensor " + "must equal 1", + clip_min_t_dims.size())); + auto tensor_type = clip_max_t.dtype(); + float max_val_ = 0.f; + float min_val_ = 1.f; + if (tensor_type == phi::DataType::FLOAT16) { + auto* clip_max_t_fp16_ptr = clip_max_t.data(); + auto* clip_min_t_fp16_ptr = clip_min_t.data(); + max_val_ = static_cast(clip_max_t_fp16_ptr[0]); + min_val_ = static_cast(clip_min_t_fp16_ptr[0]); + } else if (tensor_type == phi::DataType::FLOAT32) { + auto* clip_max_t_fp32_ptr = clip_max_t.data(); + auto* clip_min_t_fp32_ptr = clip_min_t.data(); + max_val_ = clip_max_t_fp32_ptr[0]; + min_val_ = clip_min_t_fp32_ptr[0]; + } else { + PADDLE_THROW(platform::errors::Unavailable( + "relu6_fuse_pass do not supported weight dtype. 
" + "we now only support fp32/fp16.")); + } + if (std::abs(max_val_ - 6.0) < 1e-3 && std::abs(min_val_ - 0.0) < 1e-3) { + OpDesc new_desc; + new_desc.SetType("relu6"); + new_desc.SetAttr("threshold", 6.f); + new_desc.SetInput("X", {clip_x_node->Name()}); + new_desc.SetOutput("Out", {clip_out_node->Name()}); + new_desc.Flush(); + + std::unordered_set del_node_set; + del_node_set.insert(clip_op_node); + del_node_set.insert(clip_max_node); + del_node_set.insert(clip_min_node); + GraphSafeRemoveNodes(graph, del_node_set); + + auto fused_node = graph->CreateOpNode(&new_desc); + IR_NODE_LINK_TO(clip_x_node, fused_node); + IR_NODE_LINK_TO(fused_node, clip_out_node); + } + }; + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(relu6_fuse_pass, paddle::framework::ir::Relu6FusePass); diff --git a/paddle/fluid/framework/ir/relu6_fuse_pass.h b/paddle/fluid/framework/ir/relu6_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..3e7c3f79948c493906dd5d3806e2004597ed42c7 --- /dev/null +++ b/paddle/fluid/framework/ir/relu6_fuse_pass.h @@ -0,0 +1,59 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; +/* +fuse fill_constant + clip block in to relu6 op +For example: +graph: + Min(0) Input Max(6.0) + \ | / + \ | / + clip + | + | + Output +------------------------------------------------------ +After the pass is applied: + Input + | + | + relu6 + | + | + Output +*/ + +class Relu6FusePass : public FusePassBase { + public: + virtual ~Relu6FusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"relu6_fuse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/relu6_fuse_pass_test.cc b/paddle/fluid/framework/ir/relu6_fuse_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..69b351477cf32e8ca1e951349aa10de802065920 --- /dev/null +++ b/paddle/fluid/framework/ir/relu6_fuse_pass_test.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+template <typename T>
+void AddVarToScope(Scope* param_scope,
+                   const std::string& name,
+                   const DDim& dims,
+                   T value = 0) {
+  auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
+  tensor->Resize(dims);
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  auto* data = cpu_ctx->Alloc<T>(tensor);
+  for (int64_t i = 0; i < tensor->numel(); i++) {
+    data[i] = value;
+  }
+}
+
+TEST(Relu6FusePass, basic) {
+  Layers layers;
+
+  auto* in_x = layers.data("in_x", {1, 32, 112, 112});
+  auto* clip_min = layers.data("clip_min", {1}, true);
+  auto* clip_max = layers.data("clip_max", {1}, true);
+  layers.clip(in_x, clip_min, clip_max);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto* param_scope = new Scope();
+  graph->Set("__param_scope__", param_scope);
+  AddVarToScope(param_scope, clip_min->Name(), {1}, 0.f);
+  AddVarToScope(param_scope, clip_max->Name(), {1}, 6.f);
+  auto pass = PassRegistry::Instance().Get("relu6_fuse_pass");
+  VLOG(3) << DebugString(graph);
+
+  pass->Apply(graph.get());
+  VLOG(3) << DebugString(graph);
+
+  auto clip_num = GetNumOpNodes(graph, "clip");
+  PADDLE_ENFORCE_EQ(clip_num,
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "clip should be mapped to relu6 after the pass."));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(relu6_fuse_pass);
diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
index 205740c0e24d82057bbc5f13cb2569e8a5e1cb72..ac07447f0d3a59a339018239a382339685691d0d 100644
--- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.cc
@@ -139,6 +139,76 @@ Reshape2MatmulPattern::Reshape2MatmulPattern(PDPattern* pattern,
   reshape2->LinksFrom({reshape2_in}).LinksTo({matmul_x});
   matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
 }
+
+struct Squeeze2MatmulPattern : public PatternBase {
+  Squeeze2MatmulPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(squeeze2);
+  PATTERN_DECL_NODE(matmul);
+  // declare variable node's name
+  PATTERN_DECL_NODE(squeeze2_in);
+  PATTERN_DECL_NODE(matmul_x);
+  PATTERN_DECL_NODE(matmul_y);
+  PATTERN_DECL_NODE(matmul_out);
+};
+
+Squeeze2MatmulPattern::Squeeze2MatmulPattern(PDPattern* pattern,
+                                             const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* squeeze2_in =
+      pattern->NewNode(squeeze2_in_repr())
+          ->assert_is_op_input("squeeze2", "X")
+          ->AsInput()
+          ->assert_more([](Node* node) {
+            auto squeeze2_in_x_shape = node->Var()->GetShape();
+            size_t squeeze2_in_rank = squeeze2_in_x_shape.size();
+            // Check the rank before indexing to avoid out-of-range access.
+            return squeeze2_in_rank == 4 && squeeze2_in_x_shape[2] == 1 &&
+                   squeeze2_in_x_shape[3] == 1;
+          });
+  auto* squeeze2 =
+      pattern->NewNode(squeeze2_repr())
+          ->assert_is_op("squeeze2")
+          ->assert_has_n_inputs(1)
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto squeeze2_op_axes =
+                op_desc->GetAttrIfExists<std::vector<int>>("axes");
+            return squeeze2_op_axes == std::vector<int>{2, 3};
+          });
+  auto matmul_x = pattern->NewNode(matmul_x_repr())
+                      ->assert_is_op_output("squeeze2", "Out")
+                      ->assert_has_n_outputs(1)
+                      ->assert_is_op_input("matmul", "X")
+                      ->assert_more([](Node* node) {
+                        auto matmul_x_shape = node->Var()->GetShape();
+                        size_t matmul_x_rank = matmul_x_shape.size();
+                        return matmul_x_rank == 2;
+                      });
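+  // matmul_y must be a 2-D persistable weight so the matmul can later be
+  // rewritten as mul with x_num_col_dims = y_num_col_dims = 1.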
+  auto* matmul_y = pattern->NewNode(matmul_y_repr())
+                       ->assert_is_op_input("matmul", "Y")
+                       ->assert_is_persistable_var()
+                       ->assert_more([](Node* node) {
+                         auto matmul_y_shape = node->Var()->GetShape();
+                         size_t matmul_y_rank = matmul_y_shape.size();
+                         return matmul_y_rank == 2;
+                       });
+  auto* matmul = pattern->NewNode(matmul_repr())
+                     ->assert_is_op("matmul")
+                     ->assert_op_attr<bool>("transpose_X", false)
+                     ->assert_op_attr<bool>("transpose_Y", false)
+                     ->assert_more([](Node* node) {
+                       auto* op_desc = node->Op();
+                       auto matmul_alpha_attr =
+                           op_desc->GetAttrIfExists<float>("alpha");
+                       return std::abs(matmul_alpha_attr - 1.f) < 1e-5;
+                     });
+  auto* matmul_out = pattern->NewNode(matmul_out_repr())
+                         ->assert_is_op_output("matmul", "Out")
+                         ->AsOutput();
+  squeeze2->LinksFrom({squeeze2_in}).LinksTo({matmul_x});
+  matmul->LinksFrom({matmul_x, matmul_y}).LinksTo({matmul_out});
+}
 }  // namespace patterns
 
 void Reshape2MatmulXPUFusePass::FuseReshape2Matmul(ir::Graph* graph) const {
@@ -250,6 +320,59 @@ void MapMatmulV2ToMatmulXPUPass::ApplyImpl(ir::Graph* graph) const {
   MapMatmulV2ToMatmul(graph);
 }
 
+void Squeeze2MatmulXPUFusePass::FuseSqueeze2Matmul(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::Squeeze2MatmulPattern pattern(gpd.mutable_pattern(), name_scope_);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle Squeeze2MatmulXPUFusePass";
+    /* declare operator node's name */
+    GET_IR_NODE(squeeze2);
+    GET_IR_NODE(matmul);
+    /* declare variable node's name */
+    GET_IR_NODE(squeeze2_in);
+    GET_IR_NODE(matmul_x);
+    GET_IR_NODE(matmul_y);
+    GET_IR_NODE(matmul_out);
+
+    // Only fuse when the matmul output feeds exactly one elementwise_add or
+    // batch_norm op.
+    bool flag = true;
+    std::vector<Node*>& next_ops = matmul_out->outputs;
+    flag = flag && next_ops.size() == 1 &&
+           (next_ops[0]->Name() == "elementwise_add" ||
+            next_ops[0]->Name() == "batch_norm");
+
+    if (flag) {
+      OpDesc desc(matmul->Op()->Block());
+      desc.SetType("mul");
+      desc.SetInput("X", {squeeze2_in->Name()});
+      desc.SetInput("Y", {matmul_y->Name()});
+      desc.SetOutput("Out", {matmul_out->Name()});
+      desc.SetAttr("x_num_col_dims", 1);
+      desc.SetAttr("y_num_col_dims", 1);
+
+      auto mul_node = graph->CreateOpNode(&desc);
+      IR_NODE_LINK_TO(squeeze2_in, mul_node);
+      IR_NODE_LINK_TO(matmul_y, mul_node);
+      IR_NODE_LINK_TO(mul_node, matmul_out);
+      GraphSafeRemoveNodes(graph, {squeeze2, matmul_x, matmul});
+      found_subgraph_count++;
+    }
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+void Squeeze2MatmulXPUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph,
+      platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  FuseSqueeze2Matmul(graph);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -272,3 +395,13 @@ REGISTER_PASS_CAPABILITY(map_matmulv2_to_matmul_xpu_pass)
         paddle::framework::compatible::OpVersionComparatorCombination()
             .EQ("matmul_v2", 0)
             .LE("matmul", 1));
+
+REGISTER_PASS(squeeze2_matmul_xpu_fuse_pass,
+              paddle::framework::ir::Squeeze2MatmulXPUFusePass);
+
+REGISTER_PASS_CAPABILITY(squeeze2_matmul_xpu_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("squeeze2", 0)
+            .LE("matmul", 1)
+            .EQ("mul", 0));
diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h
index bc16ea9c4057c46905081cc5111b41c694d4b696..2282a0be29cebfdf5ffeed8b4703e2cce0cea0fb 100644
--- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h
+++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass.h
@@ -31,6 +31,15 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+class Squeeze2MatmulXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void FuseSqueeze2Matmul(ir::Graph* graph) const;
+  const std::string name_scope_{"squeeze2_matmul_xpu_fuse_pass"};
+};
+
 class Reshape2MatmulXPUFusePass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
diff --git a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc
index f81e783fc051985a47b693c91660b2d19309cdb5..17da6ef07620fba3e775b43cc9f7e7537c5395b9 100644
--- a/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc
+++ b/paddle/fluid/framework/ir/xpu/reshape2_matmul_xpu_fuse_pass_test.cc
@@ -22,6 +22,32 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+TEST(Squeeze2MatmulXPUFusePass, basic) {
+  Layers layers;
+
+  // The pattern requires a 4-D input with shape[2] == shape[3] == 1 and
+  // squeeze2 axes {2, 3}, so build the subgraph accordingly.
+  auto* squeeze2_in = layers.data("squeeze2_in", {64, 74, 1, 1});
+  auto* squeeze2_out = layers.squeeze2(squeeze2_in, std::vector<int>{2, 3});
+  auto* matmul_y = layers.data("matmul_y", {74, 64}, true);
+  auto* matmul_out =
+      layers.matmul(squeeze2_out, matmul_y, nullptr, false, false);
+  auto* ele_y = layers.data("ele_y", {64}, true);
+  layers.elementwise_add(matmul_out, ele_y);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("squeeze2_matmul_xpu_fuse_pass");
+  VLOG(3) << DebugString(graph);
+
+  pass->Apply(graph.get());
+  VLOG(3) << DebugString(graph);
+
+  auto ops_num = GetNumOpNodes(graph);
+  PADDLE_ENFORCE_EQ(
+      ops_num,
+      2,
+      platform::errors::PreconditionNotMet(
+          "graph should only have 2 op nodes (mul + elementwise_add), "
+          "but received %d.",
+          ops_num));
+}
+
 TEST(ReShape2MatmulXPUFusePass, basic) {
   Layers layers;
 
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 0edcaac4335a2588b28f9b09b28fc07aa93238d7..95285f9930181d2213e798b2fc9cdb25802ebb6f 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -529,10 +529,12 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "xpu_delete_cast_op_pass",
       "stack_fuse_pass",
       "fused_multi_transformer_xpu_pass",
+      "relu6_fuse_pass",
       "sigmoid_elementmul_fuse_pass",
       "matmul_weight_trans_pass",
       "map_matmulv2_to_matmul_xpu_pass",
       "reshape2_matmul_xpu_fuse_pass",
+      "squeeze2_matmul_xpu_fuse_pass",
       "redundant_squeeze_unsqueeze_elimination_pass",
       "fc_xpu_fuse_pass",
       "conv2d_xpu_fuse_pass",