diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 88d4cd7a5e4ed6d494f1dd70d25c1c90fa1e822c..863bc63fc9f3fcd5d79fe745d5184d1732f15ccf 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -258,6 +258,7 @@ if(WITH_XPU)
                xpu DEPS ${XPU_PASS_DEPS})
   pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(xpu_delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(fold_two_squeeze2_fuse_pass inference DIR xpu DEPS
@@ -548,6 +549,10 @@ if(WITH_XPU)
     test_multi_encoder_xpu_adaptive_seqlen_fuse_pass
     SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
     DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass)
+  cc_test(
+    test_xpu_delete_cast_op_pass
+    SRCS xpu/xpu_delete_cast_op_pass_test.cc
+    DEPS xpu_delete_cast_op_pass)
   cc_test(
     test_fold_interp_outsize_fuse_pass
     SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
old mode 100644
new mode 100755
index e5ee2dc274930964e3e1a47aff42f9ecf8804655..c70a05829952c99685d540cb7b2d66c628c339da
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -67,6 +67,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
     "one_beam_size_fuse_pass",
     "stack_fuse_pass",
     "fused_multi_transformer_xpu_pass",
+    "xpu_delete_cast_op_pass",
     "fc_xpu_fuse_pass",
     "link_xpu_op_max_pass",
 };
diff --git a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f310d5b105b05c71d234064c34e04e2895638a0a
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc
@@ -0,0 +1,257 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h" +#include +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/kernels/cast_kernel.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +struct CastSoftmaxPattern : public PatternBase { + CastSoftmaxPattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(cast0); + PATTERN_DECL_NODE(softmax); + PATTERN_DECL_NODE(cast1); + // declare variable node's name + PATTERN_DECL_NODE(cast0_in); + PATTERN_DECL_NODE(cast0_out); + PATTERN_DECL_NODE(softmax_out); + PATTERN_DECL_NODE(cast1_out); +}; + +CastSoftmaxPattern::CastSoftmaxPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* cast0_in = + pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X"); + auto* cast0 = + pattern->NewNode(cast0_repr()) + ->assert_is_op("cast") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto in_dtype = op_desc->GetAttrIfExists("in_dtype"); + auto out_dtype = op_desc->GetAttrIfExists("out_dtype"); + return in_dtype == static_cast(proto::VarType::FP16) && + out_dtype == static_cast(proto::VarType::FP32); + }); + auto* cast0_out = pattern->NewNode(cast0_out_repr()) + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("softmax", "X") + ->assert_has_n_outputs(1); + auto* softmax = pattern->NewNode(softmax_repr())->assert_is_op("softmax"); + auto* softmax_out = pattern->NewNode(softmax_out_repr()) + ->assert_is_op_output("softmax", "Out") + ->assert_is_op_input("cast", "X") + ->assert_has_n_outputs(1); + auto* cast1 = + pattern->NewNode(cast1_repr()) + ->assert_is_op("cast") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto in_dtype = op_desc->GetAttrIfExists("in_dtype"); + auto out_dtype = op_desc->GetAttrIfExists("out_dtype"); + return in_dtype == static_cast(proto::VarType::FP32) && + out_dtype == static_cast(proto::VarType::FP16); + }); + auto* cast1_out = + pattern->NewNode(cast1_out_repr())->assert_is_op_output("cast", "Out"); + + cast0->LinksFrom({cast0_in}).LinksTo({cast0_out}); + softmax->LinksFrom({cast0_out}).LinksTo({softmax_out}); + cast1->LinksFrom({softmax_out}).LinksTo({cast1_out}); +} +} // namespace patterns + +int XpuDeleteCastOpPass::ApplyCastSoftmaxPass(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::CastSoftmaxPattern pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ApplyCastSoftmaxPass fuse"; + GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax, softmax, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_out, softmax_out, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern); + + softmax->Op()->RenameInput(cast0_out->Name(), cast0_in->Name()); + 
+    softmax->Op()->RenameOutput(softmax_out->Name(), cast1_out->Name());
+    IR_NODE_LINK_TO(cast0_in, softmax);
+    IR_NODE_LINK_TO(softmax, cast1_out);
+
+    std::unordered_set<const Node*> delete_nodes{
+        cast0, cast1, cast0_out, softmax_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+namespace patterns {
+struct CastLayerNormPattern : public PatternBase {
+  CastLayerNormPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(cast0);
+  PATTERN_DECL_NODE(layer_norm);
+  PATTERN_DECL_NODE(cast1);
+  // declare variable node's name
+  PATTERN_DECL_NODE(cast0_in);
+  PATTERN_DECL_NODE(cast0_out);
+  PATTERN_DECL_NODE(layer_norm_out);
+  PATTERN_DECL_NODE(cast1_out);
+};
+
+CastLayerNormPattern::CastLayerNormPattern(PDPattern* pattern,
+                                           const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* cast0_in =
+      pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
+  auto* cast0 =
+      pattern->NewNode(cast0_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::FP16) &&
+                   out_dtype == static_cast<int>(proto::VarType::FP32);
+          });
+  auto* cast0_out = pattern->NewNode(cast0_out_repr())
+                        ->assert_is_op_output("cast", "Out")
+                        ->assert_is_op_input("layer_norm", "X")
+                        ->assert_has_n_outputs(1);
+  auto* layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
+                             ->assert_is_op_output("layer_norm", "Y")
+                             ->assert_is_op_input("cast", "X")
+                             ->assert_has_n_outputs(1);
+  auto* cast1 =
+      pattern->NewNode(cast1_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::FP32) &&
+                   out_dtype == static_cast<int>(proto::VarType::FP16);
+          });
+  auto* cast1_out =
+      pattern->NewNode(cast1_out_repr())->assert_is_op_output("cast", "Out");
+
+  cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
+  layer_norm->LinksFrom({cast0_out}).LinksTo({layer_norm_out});
+  cast1->LinksFrom({layer_norm_out}).LinksTo({cast1_out});
+}
+}  // namespace patterns
+
+int XpuDeleteCastOpPass::ApplyCastLayerNormPass(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::CastLayerNormPattern pattern(gpd.mutable_pattern(), name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastLayerNormPass fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
+
+    layer_norm->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
+    layer_norm->Op()->RenameOutput(layer_norm_out->Name(), cast1_out->Name());
+    IR_NODE_LINK_TO(cast0_in, layer_norm);
+    IR_NODE_LINK_TO(layer_norm, cast1_out);
+
+    std::unordered_set<const Node*> delete_nodes{
+        cast0, cast1, cast0_out, layer_norm_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+void XpuDeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  if (!graph->IsMainGraph()) {
+    VLOG(3) << "'xpu_delete_cast_op_pass' needs info in all "
+               "graphs, so it "
+               "should be applied in the main graph.";
+    return;
+  }
+  Init(name_scope_, graph);
+
+  int found_subgraph_count = ApplyCastSoftmaxPass(graph);
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    found_subgraph_count += ApplyCastSoftmaxPass(graph->GetSubGraph(i));
+  }
+  if (found_subgraph_count > 0) {
+    LOG(INFO) << "--- delete " << found_subgraph_count
+              << " cast_softmax_cast subgraph";
+  }
+
+  found_subgraph_count = 0;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    found_subgraph_count += ApplyCastLayerNormPass(graph->GetSubGraph(i));
+  }
+  if (found_subgraph_count > 0) {
+    LOG(INFO) << "--- delete " << found_subgraph_count
+              << " cast_layer_norm_cast subgraph";
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(xpu_delete_cast_op_pass,
+              paddle::framework::ir::XpuDeleteCastOpPass);
+
+REGISTER_PASS_CAPABILITY(xpu_delete_cast_op_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "cast", 0));
diff --git a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h
new file mode 100755
index 0000000000000000000000000000000000000000..d0556e8b0bf0afc0df1105ec4ab63aca29781b3d
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class XpuDeleteCastOpPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Origin subgraph:
+        cast(fp16->fp32)
+              |
+           softmax
+              |
+        cast(fp32->fp16)
+
+  Optimized subgraph:
+           softmax
+  */
+  int ApplyCastSoftmaxPass(ir::Graph* graph) const;
+
+  /*
+  Origin subgraph:
+        cast(fp16->fp32)
+              |
+          layer_norm
+              |
+        cast(fp32->fp16)
+
+  Optimized subgraph:
+          layer_norm
+  */
+  int ApplyCastLayerNormPass(ir::Graph* graph) const;
+
+  const std::string name_scope_{"xpu_delete_cast_op_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f8682bdca9f8e4d01a0003bea3737e6edafd7f7b
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc
@@ -0,0 +1,119 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+VarDesc* Data(paddle::framework::BlockDesc* block,
+              std::string name,
+              std::vector<int64_t> shape = {},
+              bool is_persistable = false,
+              proto::VarType::Type data_type = proto::VarType::FP32) {
+  auto* var = block->Var(name);
+  var->SetType(proto::VarType::LOD_TENSOR);
+  var->SetDataType(data_type);
+  var->SetShape(shape);
+  var->SetPersistable(is_persistable);
+  return var;
+}
+
+VarDesc* AddCast(BlockDesc* block,
+                 VarDesc* input,
+                 int in_dtype = 5,
+                 int out_dtype = 5) {
+  VarDesc* out = Data(block, input->Name() + "_out");
+  OpDesc* op = block->AppendOp();
+  op->SetType("cast");
+  op->SetInput("X", {input->Name()});
+  op->SetOutput("Out", {out->Name()});
+  op->SetAttr("in_dtype", in_dtype);
+  op->SetAttr("out_dtype", out_dtype);
+  return out;
+}
+
+int GetOpNum(Graph* graph, std::string op_type = "") {
+  int num_nodes = 0;
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op() &&
+        (node->Op()->Type() == op_type || op_type.empty())) {
+      num_nodes++;
+    }
+  }
+  return num_nodes;
+}
+
+TEST(ApplyCastSoftmaxPass, basic) {
+  paddle::framework::ProgramDesc program;
+  auto* block = program.MutableBlock(0);
+  auto* cast0_in = Data(block, "cast0_in", {1});
+  auto* cast0_out = AddCast(block, cast0_in, 4, 5);
+  auto* softmax_out = Data(block, "softmax_out", {1});
+  OpDesc* softmax = block->AppendOp();
+  softmax->SetType("softmax");
+  softmax->SetInput("X", {cast0_out->Name()});
+  softmax->SetOutput("Out", {softmax_out->Name()});
+  AddCast(block, softmax_out, 5, 4);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
+  auto scope = new Scope();
+  graph->Set("__param_scope__", scope);
+  auto pass = PassRegistry::Instance().Get("xpu_delete_cast_op_pass");
+  pass->Apply(graph.get());
+  int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
+  PADDLE_ENFORCE_EQ(
+      GetOpNum(graph->GetSubGraph(0), "cast"),
+      0,
+      platform::errors::PreconditionNotMet(
+          "graph should have 0 cast after xpu_delete_cast_op_pass, "
+          "but actually has %d.",
+          cast_num_in_graph));
+}
+
+TEST(ApplyCastLayerNormPass, basic) {
+  paddle::framework::ProgramDesc program;
+  auto* block = program.MutableBlock(0);
+  auto* cast0_in = Data(block, "cast0_in", {1});
+  auto* cast0_out = AddCast(block, cast0_in, 4, 5);
+  auto* layer_norm_out = Data(block, "layer_norm_out", {1});
+  OpDesc* layer_norm = block->AppendOp();
+  layer_norm->SetType("layer_norm");
+  layer_norm->SetInput("X", {cast0_out->Name()});
+  layer_norm->SetOutput("Y", {layer_norm_out->Name()});
+  AddCast(block, layer_norm_out, 5, 4);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
+  auto scope = new Scope();
+  graph->Set("__param_scope__", scope);
+  auto pass = PassRegistry::Instance().Get("xpu_delete_cast_op_pass");
+  pass->Apply(graph.get());
+  int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
+  PADDLE_ENFORCE_EQ(
+      GetOpNum(graph->GetSubGraph(0), "cast"),
+      0,
+      platform::errors::PreconditionNotMet(
+          "graph should have 0 cast after xpu_delete_cast_op_pass, "
+          "but actually has %d.",
+          cast_num_in_graph));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(xpu_delete_cast_op_pass);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
old mode 100644
new mode 100755
index b8832132044dbedf31c5bd660a15db9b7a1f8bc7..f9372ad3082050a0002edc6535bebdefeca59ee7
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -538,6 +538,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "link_xpu_op_max_pass",
       "inplace_op_var_pass",
       "delete_isolated_node_pass",
+      "xpu_delete_cast_op_pass",
   });
   use_xpu_ = true;
 }
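Note (not part of the patch): because "xpu_delete_cast_op_pass" is appended to XpuPassStrategy above, it runs automatically during XPU inference graph optimization. A minimal usage sketch, assuming the standard paddle_infer C++ API, default EnableXpu() arguments, and placeholder model/include paths:

// Usage sketch only; the model paths and include path are hypothetical.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("model_dir/model.pdmodel", "model_dir/model.pdiparams");
  config.EnableXpu();  // selects XpuPassStrategy, which now contains xpu_delete_cast_op_pass

  // If the redundant cast ops must be kept (e.g. for debugging), the pass can
  // be dropped from the strategy explicitly:
  // config.pass_builder()->DeletePass("xpu_delete_cast_op_pass");

  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}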