Unverified commit 8b622d58, authored by zhupengyang, committed by GitHub

[XPU] add delete_cast_op_pass (#52305)

Parent 3e2d0195
......@@ -240,6 +240,7 @@ if(WITH_XPU)
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -518,4 +519,8 @@ if(WITH_XPU)
test_stack_fuse_pass
SRCS xpu/stack_fuse_pass_test.cc
DEPS stack_fuse_pass)
cc_test(
test_delete_cast_op_pass
SRCS xpu/delete_cast_op_pass_test.cc
DEPS delete_cast_op_pass)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/delete_cast_op_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct CastWritePattern : public PatternBase {
CastWritePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast0);
PATTERN_DECL_NODE(write_to_array);
// declare variable node's name
PATTERN_DECL_NODE(cast0_in);
PATTERN_DECL_NODE(cast0_out);
PATTERN_DECL_NODE(write_to_array_out);
};
CastWritePattern::CastWritePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast0_in =
pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
auto* cast0 =
pattern->NewNode(cast0_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP16) &&
out_dtype == static_cast<int>(proto::VarType::FP32);
});
auto* cast0_out = pattern->NewNode(cast0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("write_to_array", "X")
->assert_has_n_outputs(1);
auto* write_to_array =
pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
auto* write_to_array_out = pattern->NewNode(write_to_array_out_repr())
->assert_is_op_output("write_to_array", "Out");
cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
write_to_array->LinksFrom({cast0_out}).LinksTo({write_to_array_out});
}
} // namespace patterns
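// Returns every op node in `graph` that takes a variable named `input_name`
// as one of its inputs (matched by name, not by graph links).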
static std::vector<Node*> FindOpNodeWithInputName(
ir::Graph* graph, const std::string& input_name) {
std::vector<Node*> ret;
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
auto inputs = node->Op()->Inputs();
bool find_input = false;
for (auto input : inputs) {
auto input_names = input.second;
if (std::count(input_names.begin(), input_names.end(), input_name) > 0) {
find_input = true;
break;
}
}
if (find_input) ret.push_back(node);
}
return ret;
}
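// Returns every op node in `graph` that produces a variable named
// `output_name` as one of its outputs (matched by name, not by graph links).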
static std::vector<Node*> FindOpNodeWithOutputName(
ir::Graph* graph, const std::string& output_name) {
std::vector<Node*> ret;
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
auto outputs = node->Op()->Outputs();
bool find_output = false;
for (auto output : outputs) {
auto output_names = output.second;
if (std::count(output_names.begin(), output_names.end(), output_name) >
0) {
find_output = true;
break;
}
}
if (find_output) ret.push_back(node);
}
return ret;
}
int DeleteCastOpPass::ApplyCastWriteReadPass(ir::Graph* graph) const {
if (graph->SubGraphsSize() != 2) {
VLOG(3) << "ApplyCastWriteReadPass only support 2 subgraphs.";
return 0;
}
auto* graph0 = graph->GetSubGraph(0);
auto* graph1 = graph->GetSubGraph(1);
GraphPatternDetector gpd;
patterns::CastWritePattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastWriteReadPass fuse";
GET_IR_NODE(cast0);
GET_IR_NODE(write_to_array);
GET_IR_NODE(cast0_in);
GET_IR_NODE(cast0_out);
GET_IR_NODE(write_to_array_out);
    // write_to_array_out (in graph1) may not link to any op node, so we find
    // read_from_array by the name of write_to_array_out.
auto write_out_op_nodes =
FindOpNodeWithInputName(graph, write_to_array_out->Name());
if (write_out_op_nodes.size() != 1 ||
write_out_op_nodes[0]->Op()->Type() != "read_from_array")
return;
Node* read_from_array = write_out_op_nodes[0];
Node* read_from_array_out = read_from_array->outputs[0];
auto read_out_op_nodes =
FindOpNodeWithInputName(graph, read_from_array_out->Name());
if (read_out_op_nodes.size() != 1 ||
read_out_op_nodes[0]->Op()->Type() != "cast")
return;
Node* cast1 = read_out_op_nodes[0];
Node* cast1_out = cast1->outputs[0];
// find nodes in graph0
auto nodes_in_graph0 =
FindOpNodeWithOutputName(graph0, write_to_array_out->Name());
if (nodes_in_graph0.size() != 2) return;
Node* write_to_array_0 = nullptr;
Node* while_op = nullptr;
for (auto* node : nodes_in_graph0) {
if (node->Name() == "write_to_array") {
write_to_array_0 = node;
} else if (node->Name() == "while") {
while_op = node;
}
}
if (write_to_array_0 == nullptr || while_op == nullptr) return;
// modify graph0
Node* write_to_array_0_x = nullptr;
auto write_to_array_0_x_name = write_to_array_0->Op()->Input("X")[0];
for (auto* node : write_to_array_0->inputs) {
if (node->Name() == write_to_array_0_x_name) {
write_to_array_0_x = node;
break;
}
}
std::string cast_out_name = write_to_array_0_x_name + "_fp16";
VarDesc cast_out_desc(cast_out_name);
cast_out_desc.SetShape(write_to_array_0_x->Var()->GetShape());
cast_out_desc.SetDataType(proto::VarType::Type::VarType_Type_FP16);
auto* cast_out = graph0->CreateVarNode(&cast_out_desc);
auto* block = write_to_array_0->Op()->Block();
framework::OpDesc cast_op_desc(block);
cast_op_desc.SetType("cast");
cast_op_desc.SetInput("X", {write_to_array_0_x_name});
cast_op_desc.SetAttr("in_dtype", 5);
cast_op_desc.SetAttr("out_dtype", 4);
cast_op_desc.SetOutput("Out", {cast_out_name});
auto* cast = graph0->CreateOpNode(&cast_op_desc);
write_to_array_0->Op()->RenameInput(write_to_array_0_x_name, cast_out_name);
IR_NODE_UNLINK(write_to_array_0_x, write_to_array_0);
IR_NODE_LINK_TO(write_to_array_0_x, cast);
IR_NODE_LINK_TO(cast, cast_out);
IR_NODE_LINK_TO(cast_out, write_to_array_0);
// modify graph1
write_to_array->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
read_from_array->Op()->RenameOutput(read_from_array_out->Name(),
cast1_out->Name());
    IR_NODE_LINK_TO(cast0_in, write_to_array);
    IR_NODE_LINK_TO(read_from_array, cast1_out);
std::unordered_set<const Node*> delete_nodes{
cast0, cast1, cast0_out, read_from_array_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph1, handler);
return found_subgraph_count;
}
namespace patterns {
struct CastLodResetWritePattern : public PatternBase {
CastLodResetWritePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast0);
PATTERN_DECL_NODE(lod_reset);
PATTERN_DECL_NODE(write_to_array);
// declare variable node's name
PATTERN_DECL_NODE(cast0_in);
PATTERN_DECL_NODE(cast0_out);
PATTERN_DECL_NODE(lod_reset_out);
PATTERN_DECL_NODE(write_to_array_out);
};
CastLodResetWritePattern::CastLodResetWritePattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast0_in =
pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
auto* cast0 =
pattern->NewNode(cast0_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP16) &&
out_dtype == static_cast<int>(proto::VarType::FP32);
});
auto* cast0_out = pattern->NewNode(cast0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("lod_reset", "X")
->assert_has_n_outputs(1);
auto* lod_reset =
pattern->NewNode(lod_reset_repr())->assert_is_op("lod_reset");
auto* lod_reset_out = pattern->NewNode(lod_reset_out_repr())
->assert_is_op_output("lod_reset", "Out")
->assert_is_op_input("write_to_array", "X")
->assert_has_n_outputs(1);
auto* write_to_array =
pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
auto* write_to_array_out = pattern->NewNode(write_to_array_out_repr())
->assert_is_op_output("write_to_array", "Out");
cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
lod_reset->LinksFrom({cast0_out}).LinksTo({lod_reset_out});
write_to_array->LinksFrom({lod_reset_out}).LinksTo({write_to_array_out});
}
} // namespace patterns
int DeleteCastOpPass::ApplyCastLodResetWriteReadPass(ir::Graph* graph) const {
if (graph->SubGraphsSize() != 2) {
VLOG(3) << "ApplyCastLodResetWriteReadPass only support 2 subgraphs.";
return 0;
}
auto* graph0 = graph->GetSubGraph(0);
auto* graph1 = graph->GetSubGraph(1);
GraphPatternDetector gpd;
patterns::CastLodResetWritePattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastLodResetWriteReadPass fuse";
GET_IR_NODE(cast0);
GET_IR_NODE(lod_reset);
GET_IR_NODE(write_to_array);
GET_IR_NODE(cast0_in);
GET_IR_NODE(cast0_out);
GET_IR_NODE(lod_reset_out);
GET_IR_NODE(write_to_array_out);
    // write_to_array_out (in graph1) may not link to any op node, so we find
    // read_from_array by the name of write_to_array_out.
auto write_out_op_nodes =
FindOpNodeWithInputName(graph, write_to_array_out->Name());
if (write_out_op_nodes.size() != 1 ||
write_out_op_nodes[0]->Op()->Type() != "read_from_array")
return;
Node* read_from_array = write_out_op_nodes[0];
Node* read_from_array_out = read_from_array->outputs[0];
auto read_out_op_nodes =
FindOpNodeWithInputName(graph, read_from_array_out->Name());
if (read_out_op_nodes.size() != 1 ||
read_out_op_nodes[0]->Op()->Type() != "cast")
return;
Node* cast1 = read_out_op_nodes[0];
Node* cast1_out = cast1->outputs[0];
// find nodes in graph0
auto nodes_in_graph0 =
FindOpNodeWithOutputName(graph0, write_to_array_out->Name());
if (nodes_in_graph0.size() != 2) return;
Node* write_to_array_0 = nullptr;
Node* while_op = nullptr;
for (auto* node : nodes_in_graph0) {
if (node->Name() == "write_to_array") {
write_to_array_0 = node;
} else if (node->Name() == "while") {
while_op = node;
}
}
if (write_to_array_0 == nullptr || while_op == nullptr) return;
nodes_in_graph0 =
FindOpNodeWithInputName(graph0, write_to_array_out->Name());
if (nodes_in_graph0.size() != 2) return;
Node* beam_search_decode = nullptr;
while_op = nullptr;
for (auto* node : nodes_in_graph0) {
if (node->Name() == "beam_search_decode") {
beam_search_decode = node;
} else if (node->Name() == "while") {
while_op = node;
}
}
if (beam_search_decode == nullptr || while_op == nullptr) return;
// modify graph0: 1. insert cast before write_to_array_0
Node* write_to_array_0_x = nullptr;
auto write_to_array_0_x_name = write_to_array_0->Op()->Input("X")[0];
for (auto* node : write_to_array_0->inputs) {
if (node->Name() == write_to_array_0_x_name) {
write_to_array_0_x = node;
break;
}
}
std::string cast_out_name = write_to_array_0_x_name + "_fp16";
VarDesc cast_out_desc(cast_out_name);
cast_out_desc.SetShape(write_to_array_0_x->Var()->GetShape());
cast_out_desc.SetDataType(proto::VarType::Type::VarType_Type_FP16);
auto* cast_out = graph0->CreateVarNode(&cast_out_desc);
auto* block = write_to_array_0->Op()->Block();
framework::OpDesc cast_op_desc(block);
cast_op_desc.SetType("cast");
cast_op_desc.SetInput("X", {write_to_array_0_x_name});
cast_op_desc.SetAttr("in_dtype", 5);
cast_op_desc.SetAttr("out_dtype", 4);
cast_op_desc.SetOutput("Out", {cast_out_name});
auto* cast = graph0->CreateOpNode(&cast_op_desc);
write_to_array_0->Op()->RenameInput(write_to_array_0_x_name, cast_out_name);
IR_NODE_UNLINK(write_to_array_0_x, write_to_array_0);
IR_NODE_LINK_TO(write_to_array_0_x, cast);
IR_NODE_LINK_TO(cast, cast_out);
IR_NODE_LINK_TO(cast_out, write_to_array_0);
// modify graph0: 2. insert cast after beam_search_decode
Node* beam_search_decode_out_score = nullptr;
for (auto* node : beam_search_decode->outputs) {
if (node->Name() ==
beam_search_decode->Op()->Output("SentenceScores")[0]) {
beam_search_decode_out_score = node;
break;
}
}
std::string cast_in_name = beam_search_decode_out_score->Name() + "_fp16";
VarDesc cast_in_desc(cast_in_name);
cast_in_desc.SetShape(beam_search_decode_out_score->Var()->GetShape());
cast_in_desc.SetDataType(proto::VarType::Type::VarType_Type_FP16);
auto* cast_in = graph0->CreateVarNode(&cast_in_desc);
cast_op_desc = framework::OpDesc(block);
cast_op_desc.SetType("cast");
cast_op_desc.SetInput("X", {cast_in_name});
cast_op_desc.SetAttr("in_dtype", 4);
cast_op_desc.SetAttr("out_dtype", 5);
cast_op_desc.SetOutput("Out", {beam_search_decode_out_score->Name()});
cast = graph0->CreateOpNode(&cast_op_desc);
beam_search_decode->Op()->RenameOutput(beam_search_decode_out_score->Name(),
cast_in_name);
IR_NODE_UNLINK(beam_search_decode, beam_search_decode_out_score);
IR_NODE_LINK_TO(beam_search_decode, cast_in);
IR_NODE_LINK_TO(cast_in, cast);
IR_NODE_LINK_TO(cast, beam_search_decode_out_score);
// modify graph1
lod_reset->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
read_from_array->Op()->RenameOutput(read_from_array_out->Name(),
cast1_out->Name());
    IR_NODE_LINK_TO(cast0_in, lod_reset);
    IR_NODE_LINK_TO(read_from_array, cast1_out);
std::unordered_set<const Node*> delete_nodes{
cast0, cast1, cast0_out, read_from_array_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph1, handler);
return found_subgraph_count;
}
namespace patterns {
struct CastIndexSamplePattern : public PatternBase {
CastIndexSamplePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast0);
PATTERN_DECL_NODE(index_sample);
PATTERN_DECL_NODE(cast1);
// declare variable node's name
PATTERN_DECL_NODE(cast0_in);
PATTERN_DECL_NODE(cast0_out);
PATTERN_DECL_NODE(index_sample_out);
PATTERN_DECL_NODE(cast1_out);
};
CastIndexSamplePattern::CastIndexSamplePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast0_in =
pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
auto* cast0 =
pattern->NewNode(cast0_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP16) &&
out_dtype == static_cast<int>(proto::VarType::FP32);
});
auto* cast0_out = pattern->NewNode(cast0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("index_sample", "X")
->assert_has_n_outputs(1);
auto* index_sample =
pattern->NewNode(index_sample_repr())->assert_is_op("index_sample");
auto* index_sample_out = pattern->NewNode(index_sample_out_repr())
->assert_is_op_output("index_sample", "Out")
->assert_is_op_input("cast", "X")
->assert_has_n_outputs(1);
auto* cast1 =
pattern->NewNode(cast1_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP32) &&
out_dtype == static_cast<int>(proto::VarType::FP16);
});
auto* cast1_out =
pattern->NewNode(cast1_out_repr())->assert_is_op_output("cast", "Out");
cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
index_sample->LinksFrom({cast0_out}).LinksTo({index_sample_out});
cast1->LinksFrom({index_sample_out}).LinksTo({cast1_out});
}
} // namespace patterns
int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::CastIndexSamplePattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastIndexSamplePass fuse";
GET_IR_NODE(cast0);
GET_IR_NODE(index_sample);
GET_IR_NODE(cast1);
GET_IR_NODE(cast0_in);
GET_IR_NODE(cast0_out);
GET_IR_NODE(index_sample_out);
GET_IR_NODE(cast1_out);
index_sample->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
index_sample->Op()->RenameOutput(index_sample_out->Name(),
cast1_out->Name());
IR_NODE_LINK_TO(cast0_in, index_sample);
IR_NODE_LINK_TO(index_sample, cast1_out);
std::unordered_set<const Node*> delete_nodes{
cast0, cast1, cast0_out, index_sample_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
namespace patterns {
struct CastPattern : public PatternBase {
CastPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast);
// declare variable node's name
PATTERN_DECL_NODE(cast_in);
PATTERN_DECL_NODE(cast_out);
};
CastPattern::CastPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast_in =
pattern->NewNode(cast_in_repr())->assert_is_op_input("cast", "X");
auto* cast = pattern->NewNode(cast_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype =
op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == out_dtype;
});
auto* cast_out =
pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out");
cast->LinksFrom({cast_in}).LinksTo({cast_out});
}
} // namespace patterns
int DeleteCastOpPass::ApplyCastPass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::CastPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastPass fuse";
GET_IR_NODE(cast);
GET_IR_NODE(cast_in);
GET_IR_NODE(cast_out);
for (auto* out_op_node : cast_out->outputs) {
out_op_node->Op()->RenameInput(cast_out->Name(), cast_in->Name());
IR_NODE_LINK_TO(cast_in, out_op_node);
}
std::unordered_set<const Node*> delete_nodes{cast, cast_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void DeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
if (!graph->IsMainGraph()) {
VLOG(3) << "'delete_cast_op_pass' needs info in all graphs, so it "
"should be applied in the main graph.";
return;
}
Init(name_scope_, graph);
int found_subgraph_count = ApplyCastWriteReadPass(graph);
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast_write_read_cast subgraph";
}
found_subgraph_count = ApplyCastLodResetWriteReadPass(graph);
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast_lod_reset_write_read_cast subgraph";
}
found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastIndexSamplePass(graph->GetSubGraph(i));
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast_index_sample_cast subgraph";
}
found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastPass(graph->GetSubGraph(i));
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast(with same in/out dtype) subgraph";
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_cast_op_pass, paddle::framework::ir::DeleteCastOpPass);
REGISTER_PASS_CAPABILITY(delete_cast_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"cast", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class DeleteCastOpPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
/*
  Origin subgraph:
     main_graph:               while subgraph:
       write_to_array            cast(fp16->fp32)
             |                         |
       (write_var:fp32)          write_to_array
                                       |
                                 (write_var:fp32)
                                       |
                                 read_from_array
                                       |
                                 cast(fp32->fp16)

  Optimized subgraph:
     main_graph:               while subgraph:
          cast                   write_to_array
           |                          |
       write_to_array           (write_var:fp16)
           |                          |
       (write_var:fp16)         read_from_array
*/
int ApplyCastWriteReadPass(ir::Graph* graph) const;
/*
  Origin subgraph:
     main_graph:               while subgraph:
       write_to_array            cast(fp16->fp32)
             |                         |
       (write_var:fp32)            lod_reset
             |                         |
           while                 write_to_array
             |                         |
       (write_var:fp32)          (write_var:fp32)
             |                         |
      beam_search_decode         read_from_array
             |                         |
       (out_score:fp32)          cast(fp32->fp16)

  Optimized subgraph:
     main_graph:               while subgraph:
          cast                     lod_reset
           |                           |
       write_to_array            write_to_array
           |                           |
       (write_var:fp16)          (write_var:fp16)
           |                           |
          while                  read_from_array
           |
       (write_var:fp16)
           |
      beam_search_decode
           |
       cast(fp16->fp32)
           |
       (out_score:fp32)
*/
int ApplyCastLodResetWriteReadPass(ir::Graph* graph) const;
/*
Origin subgraph:
cast(fp16->fp32)
|
index_sample
|
cast(fp32->fp16)
Optimized subgraph:
index_sample
*/
int ApplyCastIndexSamplePass(ir::Graph* graph) const;
  // Delete cast if its "in_dtype" is the same as its "out_dtype"
int ApplyCastPass(ir::Graph* graph) const;
const std::string name_scope_{"delete_cast_op_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
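// Helpers below build a minimal ProgramDesc by hand. Data() declares a
// LoDTensor variable with the given shape, persistability and dtype.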
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
VarDesc* AddWriteToArray(BlockDesc* block,
std::vector<VarDesc*> x,
VarDesc* i,
VarDesc* out = nullptr) {
if (out == nullptr) {
out = Data(block, x[0]->Name() + "_out");
}
OpDesc* op = block->AppendOp();
op->SetType("write_to_array");
std::vector<std::string> x_names;
for (auto k : x) {
x_names.push_back(k->Name());
}
op->SetInput("X", x_names);
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* AddReadFromArray(BlockDesc* block, VarDesc* x, VarDesc* i) {
auto* out = Data(block, x->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("read_from_array");
op->SetInput("X", {x->Name()});
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
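// Appends a cast op; in_dtype/out_dtype use proto::VarType codes
// (4 = FP16, 5 = FP32).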
VarDesc* AddCast(BlockDesc* block,
VarDesc* input,
int in_dtype = 5,
int out_dtype = 5) {
VarDesc* out = Data(block, input->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("cast");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("in_dtype", in_dtype);
op->SetAttr("out_dtype", out_dtype);
return out;
}
VarDesc* AddLodReset(BlockDesc* block, VarDesc* input) {
VarDesc* out = Data(block, input->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("lod_reset");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
std::vector<VarDesc*> AddBeamSearchDecode(BlockDesc* block,
VarDesc* ids,
VarDesc* scores) {
VarDesc* out_ids = Data(block, ids->Name() + "_out");
VarDesc* out_scores = Data(block, scores->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("beam_search_decode");
op->SetInput("Ids", {ids->Name()});
op->SetInput("Scores", {scores->Name()});
op->SetOutput("SentenceIds", {out_ids->Name()});
op->SetOutput("SentenceScores", {out_scores->Name()});
return {out_ids, out_scores};
}
int GetOpNum(Graph* graph, std::string op_type = "") {
int num_nodes = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
(node->Op()->Type() == op_type || op_type.empty())) {
num_nodes++;
}
}
return num_nodes;
}
TEST(ApplyCastWriteReadPass, basic) {
paddle::framework::ProgramDesc program;
auto* block0 = program.MutableBlock(0);
auto* block1 = program.AppendBlock(*block0);
auto* write_0_x = Data(block0, "write_0_x", {1});
auto* write_0_i = Data(block0, "write_0_i", {1});
auto* write_0_out = AddWriteToArray(block0, {write_0_x}, write_0_i);
OpDesc* while_loop = block0->AppendOp();
while_loop->SetType("while");
while_loop->SetInput("X", {write_0_out->Name()});
while_loop->SetOutput("Out", {write_0_out->Name()});
auto* cast_1_0_in = Data(block1, "cast_1_0", {1});
auto* cast_1_0_out = AddCast(block1, cast_1_0_in, 4, 5);
auto* write_1_i = Data(block1, "write_1_i", {1});
auto* write_1_out = Data(block1, write_0_out->Name(), {1});
AddWriteToArray(block1, {cast_1_0_out}, write_1_i, write_1_out);
auto* read_1_i = Data(block1, "read_1_i", {1});
auto* read_1_out = AddReadFromArray(block1, write_1_out, read_1_i);
AddCast(block1, read_1_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph1 = GetOpNum(graph->GetSubGraph(1), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph1,
0,
platform::errors::PreconditionNotMet(
"graph1 should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph1));
int cast_num_in_graph0 = GetOpNum(graph.get(), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph0,
1,
platform::errors::PreconditionNotMet(
"graph0 should have 1 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph0));
}
TEST(ApplyCastLodResetWriteReadPass, basic) {
paddle::framework::ProgramDesc program;
auto* block0 = program.MutableBlock(0);
auto* block1 = program.AppendBlock(*block0);
auto* write_0_x = Data(block0, "write_0_x", {1});
auto* write_0_i = Data(block0, "write_0_i", {1});
auto* write_0_out = AddWriteToArray(block0, {write_0_x}, write_0_i);
OpDesc* while_loop = block0->AppendOp();
while_loop->SetType("while");
while_loop->SetInput("X", {write_0_out->Name()});
while_loop->SetOutput("Out", {write_0_out->Name()});
auto* ids = Data(block0, "ids", {1});
AddBeamSearchDecode(block0, ids, write_0_out);
auto* cast_1_0_in = Data(block1, "cast_1_0", {1});
auto* cast_1_0_out = AddCast(block1, cast_1_0_in, 4, 5);
auto* lod_reset_out = AddLodReset(block1, cast_1_0_out);
auto* write_1_i = Data(block1, "write_1_i", {1});
auto* write_1_out = Data(block1, write_0_out->Name(), {1});
AddWriteToArray(block1, {lod_reset_out}, write_1_i, write_1_out);
auto* read_1_i = Data(block1, "read_1_i", {1});
auto* read_1_out = AddReadFromArray(block1, write_1_out, read_1_i);
AddCast(block1, read_1_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph1 = GetOpNum(graph->GetSubGraph(1), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph1,
0,
platform::errors::PreconditionNotMet(
"graph1 should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph1));
int cast_num_in_graph0 = GetOpNum(graph.get(), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph0,
2,
platform::errors::PreconditionNotMet(
"graph0 should have 2 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph0));
}
TEST(ApplyCastIndexSamplePass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* cast0_in = Data(block, "cast0_in", {1});
auto* cast0_out = AddCast(block, cast0_in, 4, 5);
auto* index_sample_out = Data(block, "index_sample_out", {1});
OpDesc* index_sample = block->AppendOp();
index_sample->SetType("index_sample");
index_sample->SetInput("X", {cast0_out->Name()});
index_sample->SetOutput("Out", {index_sample_out->Name()});
AddCast(block, index_sample_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
  PADDLE_ENFORCE_EQ(cast_num_in_graph,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
TEST(ApplyCastPass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* cast0_in = Data(block, "cast0_in", {1});
AddCast(block, cast0_in, 3, 3);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
  PADDLE_ENFORCE_EQ(cast_num_in_graph,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_cast_op_pass);
......@@ -528,6 +528,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"one_beam_size_fuse_pass",
"delete_cast_op_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
......
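With delete_cast_op_pass listed in XpuPassStrategy above, it runs automatically during inference analysis whenever XPU is enabled, so no manual invocation is needed. A minimal sketch, assuming the standard paddle_infer C++ API (the model directory is a placeholder):

#include "paddle_inference_api.h"

int main() {
  // Placeholder model path; any XPU-targeted inference model will do.
  paddle_infer::Config config("./model_dir");
  // Selecting XPU makes the analysis phase use XpuPassStrategy,
  // which now contains delete_cast_op_pass.
  config.EnableXpu();
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}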
......@@ -249,6 +249,7 @@ REGISTER_OPERATOR(lod_reset_grad,
REGISTER_OP_CPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::CPUPlace, paddle::platform::float16>,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
......@@ -257,6 +258,8 @@ REGISTER_OP_CPU_KERNEL(
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>,
......@@ -265,6 +268,8 @@ REGISTER_OP_XPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
lod_reset_grad,
ops::LoDResetGradKernel<paddle::platform::CPUPlace,
paddle::platform::float16>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
......