Unverified Commit 610a47dd authored by xinxinZi, committed by GitHub

add xpu_optimize_cachekv_initialization_pass (#54809)

Parent 5a804830
@@ -214,6 +214,180 @@ int XpuDeleteCastOpPass::ApplyCastLayerNormPass(ir::Graph* graph) const {
  return found_subgraph_count;
}
namespace patterns {
struct CastCacheKVInitializationPattern : public PatternBase {
CastCacheKVInitializationPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(shape0);
PATTERN_DECL_NODE(shape1);
PATTERN_DECL_NODE(slice0);
PATTERN_DECL_NODE(slice1);
PATTERN_DECL_NODE(cast0);
PATTERN_DECL_NODE(elementwise_add);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(cast1);
PATTERN_DECL_NODE(fill_constant);
// declare variable node's name
PATTERN_DECL_NODE(shape_in);
PATTERN_DECL_NODE(shape0_out);
PATTERN_DECL_NODE(slice0_out);
PATTERN_DECL_NODE(shape1_out);
PATTERN_DECL_NODE(slice1_out);
PATTERN_DECL_NODE(cast0_out);
PATTERN_DECL_NODE(elementwise_add_in0);
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(cast1_out);
};
CastCacheKVInitializationPattern::CastCacheKVInitializationPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* shape_in =
pattern->NewNode(shape_in_repr())->assert_is_op_input("shape", "X");
auto* shape0 = pattern->NewNode(shape0_repr())->assert_is_op("shape");
auto* shape0_out = pattern->NewNode(shape0_out_repr())
->assert_is_op_output("shape", "Out")
->assert_is_op_input("slice", "X");
auto* slice0 = pattern->NewNode(slice0_repr())->assert_is_op("slice");
auto* slice0_out = pattern->NewNode(slice0_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("fill_constant", "X");
auto* shape1 = pattern->NewNode(shape1_repr())->assert_is_op("shape");
auto* shape1_out = pattern->NewNode(shape1_out_repr())
->assert_is_op_output("shape", "Out")
->assert_is_op_input("slice", "X");
auto* slice1 = pattern->NewNode(slice1_repr())->assert_is_op("slice");
auto* slice1_out = pattern->NewNode(slice1_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("cast", "X");
auto* cast0 =
pattern->NewNode(cast0_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::INT32) &&
out_dtype == static_cast<int>(proto::VarType::INT64);
});
auto* cast0_out = pattern->NewNode(cast0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("elementwise_add", "Y");
auto* elementwise_add_in0 = pattern->NewNode(elementwise_add_in0_repr())
->assert_is_op_input("elementwise_add", "X");
auto* elementwise_add =
pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
auto* elementwise_add_out =
pattern->NewNode(elementwise_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("scale", "X");
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out = pattern->NewNode(scale_out_repr())
->assert_is_op_output("scale", "Out")
->assert_is_op_input("cast", "X");
auto* cast1 =
pattern->NewNode(cast1_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::INT64) &&
out_dtype == static_cast<int>(proto::VarType::INT32);
});
auto* cast1_out = pattern->NewNode(cast1_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("fill_constant", "Y")
->assert_has_n_outputs(1);
auto* fill_constant =
pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant");
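  // Wire the edges of the subgraph to be matched.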
shape0->LinksFrom({shape_in}).LinksTo({shape0_out});
slice0->LinksFrom({shape0_out}).LinksTo({slice0_out});
shape1->LinksFrom({shape_in}).LinksTo({shape1_out});
slice1->LinksFrom({shape1_out}).LinksTo({slice1_out});
cast0->LinksFrom({slice1_out}).LinksTo({cast0_out});
elementwise_add->LinksFrom({elementwise_add_in0, cast0_out})
.LinksTo({elementwise_add_out});
scale->LinksFrom({elementwise_add_out}).LinksTo({scale_out});
cast1->LinksFrom({scale_out}).LinksTo({cast1_out});
fill_constant->LinksFrom({slice0_out, cast1_out});
}
} // namespace patterns
int XpuDeleteCastOpPass::ApplyCastCacheKVInitializationPass(
ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::CastCacheKVInitializationPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastCacheKVInitializationPass fuse";
GET_IR_NODE_FROM_SUBGRAPH(shape_in, shape_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(shape0_out, shape0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice0_out, slice0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(shape1_out, shape1_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice1_out, slice1_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_add_in0, elementwise_add_in0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise_add_out, elementwise_add_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(shape0, shape0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice0, slice0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(shape1, shape1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(slice1, slice1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(fill_constant, fill_constant, pattern);
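    // Route slice0 through shape1's output; the duplicated shape0 becomes dead.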
slice0->Op()->RenameInput(shape0_out->Name(), shape1_out->Name());
IR_NODE_UNLINK(shape0_out, slice0);
IR_NODE_LINK_TO(shape1_out, slice0);
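    // Hoist cast1 in front of elementwise_add: cast the INT64 operand down to
    // INT32 once instead of casting the INT32 slice result up to INT64.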
cast1->Op()->RenameInput(scale_out->Name(), elementwise_add_in0->Name());
elementwise_add->Op()->RenameInput(elementwise_add_in0->Name(),
cast1_out->Name());
IR_NODE_UNLINK(scale_out, cast1);
IR_NODE_UNLINK(elementwise_add_in0, elementwise_add);
IR_NODE_LINK_TO(elementwise_add_in0, cast1);
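    // fill_constant now takes scale_out directly; cast1_out feeds the add.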
fill_constant->Op()->RenameInput(cast1_out->Name(), scale_out->Name());
IR_NODE_UNLINK(cast1_out, fill_constant);
IR_NODE_LINK_TO(cast1_out, elementwise_add);
IR_NODE_LINK_TO(scale_out, fill_constant);
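    // Bypass cast0: elementwise_add consumes the INT32 slice result directly.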
elementwise_add->Op()->RenameInput(cast0_out->Name(), slice1_out->Name());
IR_NODE_UNLINK(slice1_out, cast0);
IR_NODE_UNLINK(cast0_out, elementwise_add);
IR_NODE_LINK_TO(slice1_out, elementwise_add);
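    // shape0/shape0_out and cast0/cast0_out are now dead; drop them from the
    // graph.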
    std::unordered_set<const Node*> delete_nodes{
        shape0, shape0_out, cast0, cast0_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void XpuDeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
@@ -242,6 +416,16 @@ void XpuDeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
    LOG(INFO) << "--- delete " << found_subgraph_count
              << " cast_layer_norm_cast subgraph";
  }
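  // This pattern typically appears inside control-flow sub-blocks (e.g. a
  // decoder while loop), so run the rewrite on every sub-graph.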
  found_subgraph_count = 0;
  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
    found_subgraph_count +=
        ApplyCastCacheKVInitializationPass(graph->GetSubGraph(i));
  }
  if (found_subgraph_count > 0) {
    LOG(INFO) << "--- delete " << found_subgraph_count
              << " cast_cachekv_initialization_pattern subgraph";
  }
}
}  // namespace ir
......
@@ -62,6 +62,57 @@ class XpuDeleteCastOpPass : public FusePassBase {
  */
  int ApplyCastLayerNormPass(ir::Graph* graph) const;
/*
------------------------------------------------------
sub block:
x
/ \
/ \
/ \
shape shape
| |
| slice
slice |
| (max_dec_len) cast
| \ |
| elementwise_add
| |
| scale
| |
| cast
| |
\ /
\ /
\ /
fill
------------------------------------------------------
After the pass is applied:
x
|
|
shape
/ \
/ \
/ \
slice slice
| (max_dec_len) |
| \ |
| cast |
| \ |
| elementwise_add
| |
| |
| scale
| |
\ /
\ /
\ /
fill
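
  The elementwise_add/scale chain performs only small integer index arithmetic
  on a tensor dimension, so it is assumed exact in INT32; moving the cast onto
  the (max_dec_len) operand eliminates one cast op, and reusing a single shape
  op eliminates the duplicate.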
*/
int ApplyCastCacheKVInitializationPass(ir::Graph* graph) const;
  const std::string name_scope_{"xpu_delete_cast_op_pass"};
};
......
@@ -112,6 +112,96 @@ TEST(ApplyCastLayerNormPass, basic) {
          cast_num_in_graph));
}
TEST(ApplyCastCacheKVInitializationPass, basic) {
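  // Build the cache-kv initialization subgraph (two shape+slice chains, an
  // INT32->INT64 cast, elementwise_add, scale, an INT64->INT32 cast, and a
  // fill_constant) and verify the pass leaves a single shape and single cast.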
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* shape_in =
Data(block, "shape_in", {64, 128}, false, proto::VarType::INT64);
auto* shape0_out =
Data(block, "shape0_out", {2}, false, proto::VarType::INT32);
auto* shape1_out =
Data(block, "shape1_out", {2}, false, proto::VarType::INT32);
auto* slice0_out =
Data(block, "slice0_out", {1}, false, proto::VarType::INT32);
auto* slice1_out =
Data(block, "slice1_out", {1}, false, proto::VarType::INT32);
auto* elementwise_add_in0 =
Data(block, "elementwise_add_in0", {1}, false, proto::VarType::INT64);
auto* elementwise_add_out =
Data(block, "elementwise_add_out", {1}, false, proto::VarType::INT64);
auto* scale_out = Data(block, "scale_out", {1}, false, proto::VarType::INT64);
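  // Dtypes mirror the pattern's asserts: the shape/slice outputs are INT32,
  // while the elementwise_add/scale chain runs in INT64 before the pass.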
OpDesc* shape0 = block->AppendOp();
shape0->SetType("shape");
shape0->SetInput("X", {shape_in->Name()});
shape0->SetOutput("Out", {shape0_out->Name()});
OpDesc* shape1 = block->AppendOp();
shape1->SetType("shape");
shape1->SetInput("X", {shape_in->Name()});
shape1->SetOutput("Out", {shape1_out->Name()});
OpDesc* slice0 = block->AppendOp();
slice0->SetType("slice");
slice0->SetInput("X", {shape0_out->Name()});
slice0->SetOutput("Out", {slice0_out->Name()});
OpDesc* slice1 = block->AppendOp();
slice1->SetType("slice");
slice1->SetInput("X", {shape1_out->Name()});
slice1->SetOutput("Out", {slice1_out->Name()});
  auto* cast0_out = AddCast(block,
slice1_out,
static_cast<int>(proto::VarType::INT32),
static_cast<int>(proto::VarType::INT64));
OpDesc* elementwise_add = block->AppendOp();
elementwise_add->SetType("elementwise_add");
elementwise_add->SetInput("X", {elementwise_add_in0->Name()});
elementwise_add->SetInput("Y", {cast0_out->Name()});
elementwise_add->SetOutput("Out", {elementwise_add_out->Name()});
OpDesc* scale = block->AppendOp();
scale->SetType("scale");
scale->SetInput("X", {elementwise_add_out->Name()});
scale->SetOutput("Out", {scale_out->Name()});
scale->SetAttr("scale", 1.0f);
scale->SetAttr("bias", 64.0f);
auto* cast1_out = AddCast(block,
scale_out,
static_cast<int>(proto::VarType::INT64),
static_cast<int>(proto::VarType::INT32));
OpDesc* fill_constant = block->AppendOp();
fill_constant->SetType("fill_constant");
fill_constant->SetInput("X", {slice0_out->Name()});
fill_constant->SetInput("Y", {cast1_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("xpu_delete_cast_op_pass");
pass->Apply(graph.get());
int shape_num_in_graph = GetOpNum(graph->GetSubGraph(0), "shape");
PADDLE_ENFORCE_EQ(
GetOpNum(graph->GetSubGraph(0), "shape"),
1,
platform::errors::PreconditionNotMet("graph should have 1 shape after "
"xpu_delete_cast_op_pass, "
"but actually has %d.",
shape_num_in_graph));
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
PADDLE_ENFORCE_EQ(
GetOpNum(graph->GetSubGraph(0), "cast"),
1,
platform::errors::PreconditionNotMet("graph should have 1 cast after "
"xpu_delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
......