diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 75bcdb18209175cbf25b5f04ebedfbe405aa3fd2..2292baa996d4c31b05d0b8060cbe230b719c5932 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -240,6 +240,7 @@ if(WITH_XPU)
   pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
 endif()
 
 cc_library(
@@ -518,4 +519,8 @@ if(WITH_XPU)
     test_stack_fuse_pass
     SRCS xpu/stack_fuse_pass_test.cc
     DEPS stack_fuse_pass)
+  cc_test(
+    test_delete_cast_op_pass
+    SRCS xpu/delete_cast_op_pass_test.cc
+    DEPS delete_cast_op_pass)
 endif()
diff --git a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.cc b/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fb417322476b2b2324e8ab33cfa1ca813e974611
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.cc
@@ -0,0 +1,614 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h"
+#include <string>
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+struct CastWritePattern : public PatternBase {
+  CastWritePattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(cast0);
+  PATTERN_DECL_NODE(write_to_array);
+  // declare variable node's name
+  PATTERN_DECL_NODE(cast0_in);
+  PATTERN_DECL_NODE(cast0_out);
+  PATTERN_DECL_NODE(write_to_array_out);
+};
+
+CastWritePattern::CastWritePattern(PDPattern* pattern,
+                                   const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* cast0_in =
+      pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
+  auto* cast0 =
+      pattern->NewNode(cast0_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::FP16) &&
+                   out_dtype == static_cast<int>(proto::VarType::FP32);
+          });
+  auto* cast0_out = pattern->NewNode(cast0_out_repr())
+                        ->assert_is_op_output("cast", "Out")
+                        ->assert_is_op_input("write_to_array", "X")
+                        ->assert_has_n_outputs(1);
+  auto* write_to_array =
+      pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
+  auto* write_to_array_out = pattern->NewNode(write_to_array_out_repr())
+                                 ->assert_is_op_output("write_to_array", "Out");
+
+  cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
+  write_to_array->LinksFrom({cast0_out}).LinksTo({write_to_array_out});
+}
+}  // namespace patterns
+
+static std::vector<Node*> FindOpNodeWithInputName(
+    ir::Graph* graph, const std::string& input_name) {
+  std::vector<Node*> ret;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp()) continue;
+    auto inputs = node->Op()->Inputs();
+    bool find_input = false;
+    for (auto input : inputs) {
+      auto input_names = input.second;
+      if (std::count(input_names.begin(), input_names.end(), input_name) > 0) {
+        find_input = true;
+        break;
+      }
+    }
+    if (find_input) ret.push_back(node);
+  }
+  return ret;
+}
+
+static std::vector<Node*> FindOpNodeWithOutputName(
+    ir::Graph* graph, const std::string& output_name) {
+  std::vector<Node*> ret;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp()) continue;
+    auto outputs = node->Op()->Outputs();
+    bool find_output = false;
+    for (auto output : outputs) {
+      auto output_names = output.second;
+      if (std::count(output_names.begin(), output_names.end(), output_name) >
+          0) {
+        find_output = true;
+        break;
+      }
+    }
+    if (find_output) ret.push_back(node);
+  }
+  return ret;
+}
+
+int DeleteCastOpPass::ApplyCastWriteReadPass(ir::Graph* graph) const {
+  if (graph->SubGraphsSize() != 2) {
+    VLOG(3) << "ApplyCastWriteReadPass only supports 2 subgraphs.";
+    return 0;
+  }
+  auto* graph0 = graph->GetSubGraph(0);
+  auto* graph1 = graph->GetSubGraph(1);
+  GraphPatternDetector gpd;
+  patterns::CastWritePattern pattern(gpd.mutable_pattern(), name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastWriteReadPass fuse";
+    GET_IR_NODE(cast0);
+    GET_IR_NODE(write_to_array);
+    GET_IR_NODE(cast0_in);
+    GET_IR_NODE(cast0_out);
+    GET_IR_NODE(write_to_array_out);
+
+    // write_to_array_out (in graph1) may not link to any op node, so we find
+    // read_from_array by the name of write_to_array_out.
+    auto write_out_op_nodes =
+        FindOpNodeWithInputName(graph, write_to_array_out->Name());
+    if (write_out_op_nodes.size() != 1 ||
+        write_out_op_nodes[0]->Op()->Type() != "read_from_array")
+      return;
+    Node* read_from_array = write_out_op_nodes[0];
+    Node* read_from_array_out = read_from_array->outputs[0];
+    auto read_out_op_nodes =
+        FindOpNodeWithInputName(graph, read_from_array_out->Name());
+    if (read_out_op_nodes.size() != 1 ||
+        read_out_op_nodes[0]->Op()->Type() != "cast")
+      return;
+    Node* cast1 = read_out_op_nodes[0];
+    Node* cast1_out = cast1->outputs[0];
+
+    // find nodes in graph0
+    auto nodes_in_graph0 =
+        FindOpNodeWithOutputName(graph0, write_to_array_out->Name());
+    if (nodes_in_graph0.size() != 2) return;
+    Node* write_to_array_0 = nullptr;
+    Node* while_op = nullptr;
+    for (auto* node : nodes_in_graph0) {
+      if (node->Name() == "write_to_array") {
+        write_to_array_0 = node;
+      } else if (node->Name() == "while") {
+        while_op = node;
+      }
+    }
+    if (write_to_array_0 == nullptr || while_op == nullptr) return;
+
+    // modify graph0
+    Node* write_to_array_0_x = nullptr;
+    auto write_to_array_0_x_name = write_to_array_0->Op()->Input("X")[0];
+    for (auto* node : write_to_array_0->inputs) {
+      if (node->Name() == write_to_array_0_x_name) {
+        write_to_array_0_x = node;
+        break;
+      }
+    }
+
+    std::string cast_out_name = write_to_array_0_x_name + "_fp16";
+    VarDesc cast_out_desc(cast_out_name);
+    cast_out_desc.SetShape(write_to_array_0_x->Var()->GetShape());
+    cast_out_desc.SetDataType(proto::VarType::Type::VarType_Type_FP16);
+    auto* cast_out = graph0->CreateVarNode(&cast_out_desc);
+
+    auto* block = write_to_array_0->Op()->Block();
+    framework::OpDesc cast_op_desc(block);
+    cast_op_desc.SetType("cast");
+    cast_op_desc.SetInput("X", {write_to_array_0_x_name});
+    // proto::VarType: 5 is FP32, 4 is FP16, so this cast converts fp32 to fp16
+    cast_op_desc.SetAttr("in_dtype", 5);
+    cast_op_desc.SetAttr("out_dtype", 4);
+    cast_op_desc.SetOutput("Out", {cast_out_name});
+    auto* cast = graph0->CreateOpNode(&cast_op_desc);
+
+    write_to_array_0->Op()->RenameInput(write_to_array_0_x_name, cast_out_name);
+
+    IR_NODE_UNLINK(write_to_array_0_x, write_to_array_0);
+    IR_NODE_LINK_TO(write_to_array_0_x, cast);
+    IR_NODE_LINK_TO(cast, cast_out);
+    IR_NODE_LINK_TO(cast_out, write_to_array_0);
+
+    // modify graph1
+    write_to_array->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
+    read_from_array->Op()->RenameOutput(read_from_array_out->Name(),
+                                        cast1_out->Name());
+    IR_NODE_LINK_TO(cast0_in, write_to_array);
+    IR_NODE_LINK_TO(read_from_array, cast1_out);
+
+    std::unordered_set<const Node*> delete_nodes{
+        cast0, cast1, cast0_out, read_from_array_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+
+    found_subgraph_count++;
+  };
+
+  gpd(graph1, handler);
+  return found_subgraph_count;
+}
+
+namespace patterns {
+struct CastLodResetWritePattern : public PatternBase {
+  CastLodResetWritePattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(cast0);
+  PATTERN_DECL_NODE(lod_reset);
+  PATTERN_DECL_NODE(write_to_array);
+  // declare variable node's name
+  PATTERN_DECL_NODE(cast0_in);
+  PATTERN_DECL_NODE(cast0_out);
+  PATTERN_DECL_NODE(lod_reset_out);
+  PATTERN_DECL_NODE(write_to_array_out);
+};
+
+CastLodResetWritePattern::CastLodResetWritePattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* cast0_in =
+      pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
+  auto* cast0 =
+      pattern->NewNode(cast0_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::FP16) &&
+                   out_dtype == static_cast<int>(proto::VarType::FP32);
+          });
+  auto* cast0_out = pattern->NewNode(cast0_out_repr())
+                        ->assert_is_op_output("cast", "Out")
+                        ->assert_is_op_input("lod_reset", "X")
+                        ->assert_has_n_outputs(1);
+  auto* lod_reset =
+      pattern->NewNode(lod_reset_repr())->assert_is_op("lod_reset");
+  auto* lod_reset_out = pattern->NewNode(lod_reset_out_repr())
+                            ->assert_is_op_output("lod_reset", "Out")
+                            ->assert_is_op_input("write_to_array", "X")
+                            ->assert_has_n_outputs(1);
+  auto* write_to_array =
+      pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
+  auto* write_to_array_out = pattern->NewNode(write_to_array_out_repr())
+                                 ->assert_is_op_output("write_to_array", "Out");
+
+  cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
+  lod_reset->LinksFrom({cast0_out}).LinksTo({lod_reset_out});
+  write_to_array->LinksFrom({lod_reset_out}).LinksTo({write_to_array_out});
+}
+}  // namespace patterns
+
+int DeleteCastOpPass::ApplyCastLodResetWriteReadPass(ir::Graph* graph) const {
+  if (graph->SubGraphsSize() != 2) {
+    VLOG(3) << "ApplyCastLodResetWriteReadPass only supports 2 subgraphs.";
+    return 0;
+  }
+  auto* graph0 = graph->GetSubGraph(0);
+  auto* graph1 = graph->GetSubGraph(1);
+  GraphPatternDetector gpd;
+  patterns::CastLodResetWritePattern pattern(gpd.mutable_pattern(),
+                                             name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastLodResetWriteReadPass fuse";
+    GET_IR_NODE(cast0);
+    GET_IR_NODE(lod_reset);
+    GET_IR_NODE(write_to_array);
+    GET_IR_NODE(cast0_in);
+    GET_IR_NODE(cast0_out);
+    GET_IR_NODE(lod_reset_out);
+    GET_IR_NODE(write_to_array_out);
+
+    // write_to_array_out (in graph1) may not link to any op node, so we find
+    // read_from_array by the name of write_to_array_out.
+ auto write_out_op_nodes = + FindOpNodeWithInputName(graph, write_to_array_out->Name()); + if (write_out_op_nodes.size() != 1 || + write_out_op_nodes[0]->Op()->Type() != "read_from_array") + return; + Node* read_from_array = write_out_op_nodes[0]; + Node* read_from_array_out = read_from_array->outputs[0]; + auto read_out_op_nodes = + FindOpNodeWithInputName(graph, read_from_array_out->Name()); + if (read_out_op_nodes.size() != 1 || + read_out_op_nodes[0]->Op()->Type() != "cast") + return; + Node* cast1 = read_out_op_nodes[0]; + Node* cast1_out = cast1->outputs[0]; + + // find nodes in graph0 + auto nodes_in_graph0 = + FindOpNodeWithOutputName(graph0, write_to_array_out->Name()); + if (nodes_in_graph0.size() != 2) return; + Node* write_to_array_0 = nullptr; + Node* while_op = nullptr; + for (auto* node : nodes_in_graph0) { + if (node->Name() == "write_to_array") { + write_to_array_0 = node; + } else if (node->Name() == "while") { + while_op = node; + } + } + if (write_to_array_0 == nullptr || while_op == nullptr) return; + + nodes_in_graph0 = + FindOpNodeWithInputName(graph0, write_to_array_out->Name()); + if (nodes_in_graph0.size() != 2) return; + Node* beam_search_decode = nullptr; + while_op = nullptr; + for (auto* node : nodes_in_graph0) { + if (node->Name() == "beam_search_decode") { + beam_search_decode = node; + } else if (node->Name() == "while") { + while_op = node; + } + } + if (beam_search_decode == nullptr || while_op == nullptr) return; + + // modify graph0: 1. insert cast before write_to_array_0 + Node* write_to_array_0_x = nullptr; + auto write_to_array_0_x_name = write_to_array_0->Op()->Input("X")[0]; + for (auto* node : write_to_array_0->inputs) { + if (node->Name() == write_to_array_0_x_name) { + write_to_array_0_x = node; + break; + } + } + + std::string cast_out_name = write_to_array_0_x_name + "_fp16"; + VarDesc cast_out_desc(cast_out_name); + cast_out_desc.SetShape(write_to_array_0_x->Var()->GetShape()); + cast_out_desc.SetDataType(proto::VarType::Type::VarType_Type_FP16); + auto* cast_out = graph0->CreateVarNode(&cast_out_desc); + + auto* block = write_to_array_0->Op()->Block(); + framework::OpDesc cast_op_desc(block); + cast_op_desc.SetType("cast"); + cast_op_desc.SetInput("X", {write_to_array_0_x_name}); + cast_op_desc.SetAttr("in_dtype", 5); + cast_op_desc.SetAttr("out_dtype", 4); + cast_op_desc.SetOutput("Out", {cast_out_name}); + auto* cast = graph0->CreateOpNode(&cast_op_desc); + + write_to_array_0->Op()->RenameInput(write_to_array_0_x_name, cast_out_name); + IR_NODE_UNLINK(write_to_array_0_x, write_to_array_0); + IR_NODE_LINK_TO(write_to_array_0_x, cast); + IR_NODE_LINK_TO(cast, cast_out); + IR_NODE_LINK_TO(cast_out, write_to_array_0); + + // modify graph0: 2. 
insert cast after beam_search_decode + Node* beam_search_decode_out_score = nullptr; + for (auto* node : beam_search_decode->outputs) { + if (node->Name() == + beam_search_decode->Op()->Output("SentenceScores")[0]) { + beam_search_decode_out_score = node; + break; + } + } + + std::string cast_in_name = beam_search_decode_out_score->Name() + "_fp16"; + VarDesc cast_in_desc(cast_in_name); + cast_in_desc.SetShape(beam_search_decode_out_score->Var()->GetShape()); + cast_in_desc.SetDataType(proto::VarType::Type::VarType_Type_FP16); + auto* cast_in = graph0->CreateVarNode(&cast_in_desc); + + cast_op_desc = framework::OpDesc(block); + cast_op_desc.SetType("cast"); + cast_op_desc.SetInput("X", {cast_in_name}); + cast_op_desc.SetAttr("in_dtype", 4); + cast_op_desc.SetAttr("out_dtype", 5); + cast_op_desc.SetOutput("Out", {beam_search_decode_out_score->Name()}); + cast = graph0->CreateOpNode(&cast_op_desc); + + beam_search_decode->Op()->RenameOutput(beam_search_decode_out_score->Name(), + cast_in_name); + IR_NODE_UNLINK(beam_search_decode, beam_search_decode_out_score); + IR_NODE_LINK_TO(beam_search_decode, cast_in); + IR_NODE_LINK_TO(cast_in, cast); + IR_NODE_LINK_TO(cast, beam_search_decode_out_score); + + // modify graph1 + lod_reset->Op()->RenameInput(cast0_out->Name(), cast0_in->Name()); + read_from_array->Op()->RenameOutput(read_from_array_out->Name(), + cast1_out->Name()); + IR_NODE_LINK_TO(cast0, lod_reset); + IR_NODE_LINK_TO(read_from_array_out, cast1_out); + + std::unordered_set delete_nodes{ + cast0, cast1, cast0_out, read_from_array_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + + found_subgraph_count++; + }; + + gpd(graph1, handler); + return found_subgraph_count; +} + +namespace patterns { +struct CastIndexSamplePattern : public PatternBase { + CastIndexSamplePattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(cast0); + PATTERN_DECL_NODE(index_sample); + PATTERN_DECL_NODE(cast1); + // declare variable node's name + PATTERN_DECL_NODE(cast0_in); + PATTERN_DECL_NODE(cast0_out); + PATTERN_DECL_NODE(index_sample_out); + PATTERN_DECL_NODE(cast1_out); +}; + +CastIndexSamplePattern::CastIndexSamplePattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* cast0_in = + pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X"); + auto* cast0 = + pattern->NewNode(cast0_repr()) + ->assert_is_op("cast") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto in_dtype = op_desc->GetAttrIfExists("in_dtype"); + auto out_dtype = op_desc->GetAttrIfExists("out_dtype"); + return in_dtype == static_cast(proto::VarType::FP16) && + out_dtype == static_cast(proto::VarType::FP32); + }); + auto* cast0_out = pattern->NewNode(cast0_out_repr()) + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("index_sample", "X") + ->assert_has_n_outputs(1); + auto* index_sample = + pattern->NewNode(index_sample_repr())->assert_is_op("index_sample"); + auto* index_sample_out = pattern->NewNode(index_sample_out_repr()) + ->assert_is_op_output("index_sample", "Out") + ->assert_is_op_input("cast", "X") + ->assert_has_n_outputs(1); + auto* cast1 = + pattern->NewNode(cast1_repr()) + ->assert_is_op("cast") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto in_dtype = op_desc->GetAttrIfExists("in_dtype"); + auto out_dtype = op_desc->GetAttrIfExists("out_dtype"); + return in_dtype == static_cast(proto::VarType::FP32) && + out_dtype == 
static_cast(proto::VarType::FP16); + }); + auto* cast1_out = + pattern->NewNode(cast1_out_repr())->assert_is_op_output("cast", "Out"); + + cast0->LinksFrom({cast0_in}).LinksTo({cast0_out}); + index_sample->LinksFrom({cast0_out}).LinksTo({index_sample_out}); + cast1->LinksFrom({index_sample_out}).LinksTo({cast1_out}); +} +} // namespace patterns + +int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::CastIndexSamplePattern pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ApplyCastIndexSamplePass fuse"; + GET_IR_NODE(cast0); + GET_IR_NODE(index_sample); + GET_IR_NODE(cast1); + GET_IR_NODE(cast0_in); + GET_IR_NODE(cast0_out); + GET_IR_NODE(index_sample_out); + GET_IR_NODE(cast1_out); + + index_sample->Op()->RenameInput(cast0_out->Name(), cast0_in->Name()); + index_sample->Op()->RenameOutput(index_sample_out->Name(), + cast1_out->Name()); + IR_NODE_LINK_TO(cast0_in, index_sample); + IR_NODE_LINK_TO(index_sample, cast1_out); + + std::unordered_set delete_nodes{ + cast0, cast1, cast0_out, index_sample_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +namespace patterns { +struct CastPattern : public PatternBase { + CastPattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(cast); + // declare variable node's name + PATTERN_DECL_NODE(cast_in); + PATTERN_DECL_NODE(cast_out); +}; + +CastPattern::CastPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* cast_in = + pattern->NewNode(cast_in_repr())->assert_is_op_input("cast", "X"); + auto* cast = pattern->NewNode(cast_repr()) + ->assert_is_op("cast") + ->assert_more([](Node* node) { + auto* op_desc = node->Op(); + auto in_dtype = op_desc->GetAttrIfExists("in_dtype"); + auto out_dtype = + op_desc->GetAttrIfExists("out_dtype"); + return in_dtype == out_dtype; + }); + auto* cast_out = + pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out"); + + cast->LinksFrom({cast_in}).LinksTo({cast_out}); +} +} // namespace patterns + +int DeleteCastOpPass::ApplyCastPass(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::CastPattern pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle ApplyCastPass fuse"; + GET_IR_NODE(cast); + GET_IR_NODE(cast_in); + GET_IR_NODE(cast_out); + for (auto* out_op_node : cast_out->outputs) { + out_op_node->Op()->RenameInput(cast_out->Name(), cast_in->Name()); + IR_NODE_LINK_TO(cast_in, out_op_node); + } + std::unordered_set delete_nodes{cast, cast_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +void DeleteCastOpPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + if (!graph->IsMainGraph()) { + VLOG(3) << "'delete_cast_op_pass' needs info in all graphs, so it " + "should be applied in the main graph."; + return; + } + Init(name_scope_, graph); + + int found_subgraph_count = ApplyCastWriteReadPass(graph); + if (found_subgraph_count > 0) { + LOG(INFO) << "--- delete " << 
found_subgraph_count + << " cast_write_read_cast subgraph"; + } + + found_subgraph_count = ApplyCastLodResetWriteReadPass(graph); + if (found_subgraph_count > 0) { + LOG(INFO) << "--- delete " << found_subgraph_count + << " cast_lod_reset_write_read_cast subgraph"; + } + + found_subgraph_count = 0; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + found_subgraph_count += ApplyCastIndexSamplePass(graph->GetSubGraph(i)); + } + if (found_subgraph_count > 0) { + LOG(INFO) << "--- delete " << found_subgraph_count + << " cast_index_sample_cast subgraph"; + } + + found_subgraph_count = 0; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + found_subgraph_count += ApplyCastPass(graph->GetSubGraph(i)); + } + if (found_subgraph_count > 0) { + LOG(INFO) << "--- delete " << found_subgraph_count + << " cast(with same in/out dtype) subgraph"; + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_cast_op_pass, paddle::framework::ir::DeleteCastOpPass); + +REGISTER_PASS_CAPABILITY(delete_cast_op_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "cast", 0)); diff --git a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h b/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..f0010a851f722520f4140ec87c7bd7bdf128cdb3 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h @@ -0,0 +1,122 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class DeleteCastOpPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Origin subgraph:
+        main_graph:              while subgraph:
+
+       write_to_array            cast(fp16->fp32)
+             |                          |
+      (write_var:fp32)            write_to_array
+                                        |
+                                 (write_var:fp32)
+                                        |
+                                 read_from_array
+                                        |
+                                 cast(fp32->fp16)
+
+  Optimized subgraph:
+        main_graph:              while subgraph:
+
+            cast                  write_to_array
+             |                          |
+       write_to_array            (write_var:fp16)
+             |                          |
+      (write_var:fp16)           read_from_array
+  */
+  int ApplyCastWriteReadPass(ir::Graph* graph) const;
+
+  /*
+  Origin subgraph:
+        main_graph:              while subgraph:
+
+       write_to_array            cast(fp16->fp32)
+             |                          |
+      (write_var:fp32)               lod_reset
+             |                          |
+           while                  write_to_array
+             |                          |
+      (write_var:fp32)           (write_var:fp32)
+             |                          |
+     beam_search_decode          read_from_array
+             |                          |
+      (out_score:fp32)           cast(fp32->fp16)
+
+  Optimized subgraph:
+        main_graph:              while subgraph:
+
+            cast                     lod_reset
+             |                          |
+       write_to_array             write_to_array
+             |                          |
+      (write_var:fp16)           (write_var:fp16)
+             |                          |
+           while                 read_from_array
+             |
+      (write_var:fp16)
+             |
+     beam_search_decode
+             |
+       cast(fp16->fp32)
+             |
+      (out_score:fp32)
+  */
+  int ApplyCastLodResetWriteReadPass(ir::Graph* graph) const;
+
+  /*
+  Origin subgraph:
+      cast(fp16->fp32)
+             |
+        index_sample
+             |
+      cast(fp32->fp16)
+
+  Optimized subgraph:
+        index_sample
+  */
+  int ApplyCastIndexSamplePass(ir::Graph* graph) const;
+
+  // Delete cast if its "in_dtype" is the same as "out_dtype"
+  int ApplyCastPass(ir::Graph* graph) const;
+
+  const std::string name_scope_{"delete_cast_op_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/xpu/delete_cast_op_pass_test.cc b/paddle/fluid/framework/ir/xpu/delete_cast_op_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..570eae7825e35c5116cf5f937a8b93630cb49cba
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/delete_cast_op_pass_test.cc
@@ -0,0 +1,252 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +VarDesc* Data(paddle::framework::BlockDesc* block, + std::string name, + std::vector shape = {}, + bool is_persistable = false, + proto::VarType::Type data_type = proto::VarType::FP32) { + auto* var = block->Var(name); + var->SetType(proto::VarType::LOD_TENSOR); + var->SetDataType(data_type); + var->SetShape(shape); + var->SetPersistable(is_persistable); + return var; +} + +VarDesc* AddWriteToArray(BlockDesc* block, + std::vector x, + VarDesc* i, + VarDesc* out = nullptr) { + if (out == nullptr) { + out = Data(block, x[0]->Name() + "_out"); + } + OpDesc* op = block->AppendOp(); + op->SetType("write_to_array"); + std::vector x_names; + for (auto k : x) { + x_names.push_back(k->Name()); + } + op->SetInput("X", x_names); + op->SetInput("I", {i->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; +} + +VarDesc* AddReadFromArray(BlockDesc* block, VarDesc* x, VarDesc* i) { + auto* out = Data(block, x->Name() + "_out"); + OpDesc* op = block->AppendOp(); + op->SetType("read_from_array"); + op->SetInput("X", {x->Name()}); + op->SetInput("I", {i->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; +} + +VarDesc* AddCast(BlockDesc* block, + VarDesc* input, + int in_dtype = 5, + int out_dtype = 5) { + VarDesc* out = Data(block, input->Name() + "_out"); + OpDesc* op = block->AppendOp(); + op->SetType("cast"); + op->SetInput("X", {input->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("in_dtype", in_dtype); + op->SetAttr("out_dtype", out_dtype); + return out; +} + +VarDesc* AddLodReset(BlockDesc* block, VarDesc* input) { + VarDesc* out = Data(block, input->Name() + "_out"); + OpDesc* op = block->AppendOp(); + op->SetType("lod_reset"); + op->SetInput("X", {input->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; +} + +std::vector AddBeamSearchDecode(BlockDesc* block, + VarDesc* ids, + VarDesc* scores) { + VarDesc* out_ids = Data(block, ids->Name() + "_out"); + VarDesc* out_scores = Data(block, scores->Name() + "_out"); + OpDesc* op = block->AppendOp(); + op->SetType("beam_search_decode"); + op->SetInput("Ids", {ids->Name()}); + op->SetInput("Scores", {scores->Name()}); + op->SetOutput("SentenceIds", {out_ids->Name()}); + op->SetOutput("SentenceScores", {out_scores->Name()}); + return {out_ids, out_scores}; +} + +int GetOpNum(Graph* graph, std::string op_type = "") { + int num_nodes = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op() && + (node->Op()->Type() == op_type || op_type.empty())) { + num_nodes++; + } + } + return num_nodes; +} + +TEST(ApplyCastWriteReadPass, basic) { + paddle::framework::ProgramDesc program; + auto* block0 = program.MutableBlock(0); + auto* block1 = program.AppendBlock(*block0); + auto* write_0_x = Data(block0, "write_0_x", {1}); + auto* write_0_i = Data(block0, "write_0_i", {1}); + auto* write_0_out = AddWriteToArray(block0, {write_0_x}, write_0_i); + OpDesc* while_loop = block0->AppendOp(); + while_loop->SetType("while"); + while_loop->SetInput("X", {write_0_out->Name()}); + while_loop->SetOutput("Out", {write_0_out->Name()}); + + auto* cast_1_0_in = Data(block1, "cast_1_0", {1}); + auto* cast_1_0_out = AddCast(block1, cast_1_0_in, 4, 5); + auto* write_1_i = Data(block1, "write_1_i", {1}); + auto* write_1_out = Data(block1, write_0_out->Name(), {1}); + AddWriteToArray(block1, {cast_1_0_out}, write_1_i, write_1_out); + auto* 
read_1_i = Data(block1, "read_1_i", {1}); + auto* read_1_out = AddReadFromArray(block1, write_1_out, read_1_i); + AddCast(block1, read_1_out, 5, 4); + + std::unique_ptr graph(new ir::Graph(program)); + auto scope = new Scope(); + graph->Set("__param_scope__", scope); + auto pass = PassRegistry::Instance().Get("delete_cast_op_pass"); + pass->Apply(graph.get()); + + int cast_num_in_graph1 = GetOpNum(graph->GetSubGraph(1), "cast"); + PADDLE_ENFORCE_EQ(cast_num_in_graph1, + 0, + platform::errors::PreconditionNotMet( + "graph1 should have 0 cast after delete_cast_op_pass, " + "but actually has %d.", + cast_num_in_graph1)); + int cast_num_in_graph0 = GetOpNum(graph.get(), "cast"); + PADDLE_ENFORCE_EQ(cast_num_in_graph0, + 1, + platform::errors::PreconditionNotMet( + "graph0 should have 1 cast after delete_cast_op_pass, " + "but actually has %d.", + cast_num_in_graph0)); +} + +TEST(ApplyCastLodResetWriteReadPass, basic) { + paddle::framework::ProgramDesc program; + auto* block0 = program.MutableBlock(0); + auto* block1 = program.AppendBlock(*block0); + + auto* write_0_x = Data(block0, "write_0_x", {1}); + auto* write_0_i = Data(block0, "write_0_i", {1}); + auto* write_0_out = AddWriteToArray(block0, {write_0_x}, write_0_i); + OpDesc* while_loop = block0->AppendOp(); + while_loop->SetType("while"); + while_loop->SetInput("X", {write_0_out->Name()}); + while_loop->SetOutput("Out", {write_0_out->Name()}); + auto* ids = Data(block0, "ids", {1}); + AddBeamSearchDecode(block0, ids, write_0_out); + + auto* cast_1_0_in = Data(block1, "cast_1_0", {1}); + auto* cast_1_0_out = AddCast(block1, cast_1_0_in, 4, 5); + auto* lod_reset_out = AddLodReset(block1, cast_1_0_out); + auto* write_1_i = Data(block1, "write_1_i", {1}); + auto* write_1_out = Data(block1, write_0_out->Name(), {1}); + AddWriteToArray(block1, {lod_reset_out}, write_1_i, write_1_out); + auto* read_1_i = Data(block1, "read_1_i", {1}); + auto* read_1_out = AddReadFromArray(block1, write_1_out, read_1_i); + AddCast(block1, read_1_out, 5, 4); + + std::unique_ptr graph(new ir::Graph(program)); + auto scope = new Scope(); + graph->Set("__param_scope__", scope); + auto pass = PassRegistry::Instance().Get("delete_cast_op_pass"); + pass->Apply(graph.get()); + + int cast_num_in_graph1 = GetOpNum(graph->GetSubGraph(1), "cast"); + PADDLE_ENFORCE_EQ(cast_num_in_graph1, + 0, + platform::errors::PreconditionNotMet( + "graph1 should have 0 cast after delete_cast_op_pass, " + "but actually has %d.", + cast_num_in_graph1)); + int cast_num_in_graph0 = GetOpNum(graph.get(), "cast"); + PADDLE_ENFORCE_EQ(cast_num_in_graph0, + 2, + platform::errors::PreconditionNotMet( + "graph0 should have 2 cast after delete_cast_op_pass, " + "but actually has %d.", + cast_num_in_graph0)); +} + +TEST(ApplyCastIndexSamplePass, basic) { + paddle::framework::ProgramDesc program; + auto* block = program.MutableBlock(0); + auto* cast0_in = Data(block, "cast0_in", {1}); + auto* cast0_out = AddCast(block, cast0_in, 4, 5); + auto* index_sample_out = Data(block, "index_sample_out", {1}); + OpDesc* index_sample = block->AppendOp(); + index_sample->SetType("index_sample"); + index_sample->SetInput("X", {cast0_out->Name()}); + index_sample->SetOutput("Out", {index_sample_out->Name()}); + AddCast(block, index_sample_out, 5, 4); + + std::unique_ptr graph(new ir::Graph(program)); + auto scope = new Scope(); + graph->Set("__param_scope__", scope); + auto pass = PassRegistry::Instance().Get("delete_cast_op_pass"); + pass->Apply(graph.get()); + int cast_num_in_graph = 
GetOpNum(graph->GetSubGraph(0), "cast"); + PADDLE_ENFORCE_EQ(GetOpNum(graph->GetSubGraph(0), "cast"), + 0, + platform::errors::PreconditionNotMet( + "graph should have 0 cast after delete_cast_op_pass, " + "but actually has %d.", + cast_num_in_graph)); +} + +TEST(ApplyCastPass, basic) { + paddle::framework::ProgramDesc program; + auto* block = program.MutableBlock(0); + auto* cast0_in = Data(block, "cast0_in", {1}); + AddCast(block, cast0_in, 3, 3); + std::unique_ptr graph(new ir::Graph(program)); + auto scope = new Scope(); + graph->Set("__param_scope__", scope); + auto pass = PassRegistry::Instance().Get("delete_cast_op_pass"); + pass->Apply(graph.get()); + int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast"); + PADDLE_ENFORCE_EQ(GetOpNum(graph->GetSubGraph(0), "cast"), + 0, + platform::errors::PreconditionNotMet( + "graph should have 0 cast after delete_cast_op_pass, " + "but actually has %d.", + cast_num_in_graph)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(delete_cast_op_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 99a6f336ff3704de5dfe5f32c3167ef580d7bb64..2610d6de2fa62f542f6aa4ce378eb48ea80a608b 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -528,6 +528,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "multi_encoder_xpu_fuse_pass", "multi_encoder_xpu_slice_fuse_pass", "one_beam_size_fuse_pass", + "delete_cast_op_pass", "stack_fuse_pass", "fused_multi_transformer_xpu_quant_pass", "fc_xpu_fuse_pass", diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 3c22660f8e4f3758cbbd595a3355d8fd0fdd3899..e4d4ee2ce9319021d2f9886307347d9d5b4a0b4c 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -249,6 +249,7 @@ REGISTER_OPERATOR(lod_reset_grad, REGISTER_OP_CPU_KERNEL( lod_reset, + ops::LoDResetKernel, ops::LoDResetKernel, ops::LoDResetKernel, ops::LoDResetKernel, @@ -257,6 +258,8 @@ REGISTER_OP_CPU_KERNEL( #ifdef PADDLE_WITH_XPU REGISTER_OP_XPU_KERNEL( lod_reset, + ops::LoDResetKernel, ops::LoDResetKernel, ops::LoDResetKernel, ops::LoDResetKernel, @@ -265,6 +268,8 @@ REGISTER_OP_XPU_KERNEL( REGISTER_OP_CPU_KERNEL( lod_reset_grad, + ops::LoDResetGradKernel, ops::LoDResetGradKernel, ops::LoDResetGradKernel, ops::LoDResetGradKernel,
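
Note: because the pass is appended to XpuPassStrategy in paddle_pass_builder.cc above, it runs automatically whenever XPU inference is enabled, with no extra wiring. The sketch below is a minimal, hypothetical way to exercise it through the C++ inference API; the model directory is a placeholder and EnableXpu arguments are left at their defaults. It is illustrative only and not part of the patch.

  #include "paddle/fluid/inference/api/paddle_inference_api.h"

  int main() {
    // Hypothetical model directory; any XPU-targeted inference config will do.
    paddle_infer::Config config("./model_dir");
    // EnableXpu selects XpuPassStrategy, which now contains delete_cast_op_pass,
    // so the pass is applied while the predictor builds its optimized program.
    config.EnableXpu();
    auto predictor = paddle_infer::CreatePredictor(config);
    return predictor != nullptr ? 0 : 1;
  }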