From 610a47ddf526ec121a8987591fe8e28d17af98b0 Mon Sep 17 00:00:00 2001
From: xinxinZi <71006692+fengxin-hello@users.noreply.github.com>
Date: Tue, 27 Jun 2023 13:33:23 +0800
Subject: [PATCH] add xpu_optimize_cachekv_initialization_pass (#54809)

---
 .../ir/xpu/xpu_delete_cast_op_pass.cc      | 184 ++++++++++++++++++
 .../ir/xpu/xpu_delete_cast_op_pass.h       |  51 +++++
 .../ir/xpu/xpu_delete_cast_op_pass_test.cc |  90 +++++++++
 3 files changed, 325 insertions(+)

diff --git a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc
index f310d5b105b..6f7ff2264ac 100644
--- a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.cc
@@ -214,6 +214,180 @@ int XpuDeleteCastOpPass::ApplyCastLayerNormPass(ir::Graph* graph) const {
   return found_subgraph_count;
 }
 
+namespace patterns {
+struct CastCacheKVInitializationPattern : public PatternBase {
+  CastCacheKVInitializationPattern(PDPattern* pattern,
+                                   const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(shape0);
+  PATTERN_DECL_NODE(shape1);
+  PATTERN_DECL_NODE(slice0);
+  PATTERN_DECL_NODE(slice1);
+  PATTERN_DECL_NODE(cast0);
+  PATTERN_DECL_NODE(elementwise_add);
+  PATTERN_DECL_NODE(scale);
+  PATTERN_DECL_NODE(cast1);
+  PATTERN_DECL_NODE(fill_constant);
+
+  // declare variable node's name
+  PATTERN_DECL_NODE(shape_in);
+  PATTERN_DECL_NODE(shape0_out);
+  PATTERN_DECL_NODE(slice0_out);
+  PATTERN_DECL_NODE(shape1_out);
+  PATTERN_DECL_NODE(slice1_out);
+  PATTERN_DECL_NODE(cast0_out);
+  PATTERN_DECL_NODE(elementwise_add_in0);
+  PATTERN_DECL_NODE(elementwise_add_out);
+  PATTERN_DECL_NODE(scale_out);
+  PATTERN_DECL_NODE(cast1_out);
+};
+
+CastCacheKVInitializationPattern::CastCacheKVInitializationPattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* shape_in =
+      pattern->NewNode(shape_in_repr())->assert_is_op_input("shape", "X");
+
+  auto* shape0 = pattern->NewNode(shape0_repr())->assert_is_op("shape");
+  auto* shape0_out = pattern->NewNode(shape0_out_repr())
+                         ->assert_is_op_output("shape", "Out")
+                         ->assert_is_op_input("slice", "X");
+  auto* slice0 = pattern->NewNode(slice0_repr())->assert_is_op("slice");
+  auto* slice0_out = pattern->NewNode(slice0_out_repr())
+                         ->assert_is_op_output("slice", "Out")
+                         ->assert_is_op_input("fill_constant", "X");
+
+  auto* shape1 = pattern->NewNode(shape1_repr())->assert_is_op("shape");
+  auto* shape1_out = pattern->NewNode(shape1_out_repr())
+                         ->assert_is_op_output("shape", "Out")
+                         ->assert_is_op_input("slice", "X");
+  auto* slice1 = pattern->NewNode(slice1_repr())->assert_is_op("slice");
+  auto* slice1_out = pattern->NewNode(slice1_out_repr())
+                         ->assert_is_op_output("slice", "Out")
+                         ->assert_is_op_input("cast", "X");
+  auto* cast0 =
+      pattern->NewNode(cast0_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::INT32) &&
+                   out_dtype == static_cast<int>(proto::VarType::INT64);
+          });
+  auto* cast0_out = pattern->NewNode(cast0_out_repr())
+                        ->assert_is_op_output("cast", "Out")
+                        ->assert_is_op_input("elementwise_add", "Y");
+  auto* elementwise_add_in0 = pattern->NewNode(elementwise_add_in0_repr())
+                                  ->assert_is_op_input("elementwise_add", "X");
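+  // The remaining nodes pin down the int64 detour between the two casts:
+  // cast0 lifts the int32 slice result to int64 so it can be added to
+  // elementwise_add_in0 (max_dec_len in the header diagram) and scaled,
+  // and cast1 lowers the result back to int32 for fill_constant.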
+  auto* elementwise_add =
+      pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
+  auto* elementwise_add_out =
+      pattern->NewNode(elementwise_add_out_repr())
+          ->assert_is_op_output("elementwise_add", "Out")
+          ->assert_is_op_input("scale", "X");
+  auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
+
+  auto* scale_out = pattern->NewNode(scale_out_repr())
+                        ->assert_is_op_output("scale", "Out")
+                        ->assert_is_op_input("cast", "X");
+
+  auto* cast1 =
+      pattern->NewNode(cast1_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::INT64) &&
+                   out_dtype == static_cast<int>(proto::VarType::INT32);
+          });
+  auto* cast1_out = pattern->NewNode(cast1_out_repr())
+                        ->assert_is_op_output("cast", "Out")
+                        ->assert_is_op_input("fill_constant", "Y")
+                        ->assert_has_n_outputs(1);
+  auto* fill_constant =
+      pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant");
+
+  shape0->LinksFrom({shape_in}).LinksTo({shape0_out});
+  slice0->LinksFrom({shape0_out}).LinksTo({slice0_out});
+  shape1->LinksFrom({shape_in}).LinksTo({shape1_out});
+  slice1->LinksFrom({shape1_out}).LinksTo({slice1_out});
+  cast0->LinksFrom({slice1_out}).LinksTo({cast0_out});
+  elementwise_add->LinksFrom({elementwise_add_in0, cast0_out})
+      .LinksTo({elementwise_add_out});
+  scale->LinksFrom({elementwise_add_out}).LinksTo({scale_out});
+  cast1->LinksFrom({scale_out}).LinksTo({cast1_out});
+  fill_constant->LinksFrom({slice0_out, cast1_out});
+}
+}  // namespace patterns
+
+int XpuDeleteCastOpPass::ApplyCastCacheKVInitializationPass(
+    ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::CastCacheKVInitializationPattern pattern(gpd.mutable_pattern(),
+                                                     name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastCacheKVInitializationPass fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(shape_in, shape_in, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(shape0_out, shape0_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(slice0_out, slice0_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(shape1_out, shape1_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(slice1_out, slice1_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_add_in0, elementwise_add_in0, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        elementwise_add_out, elementwise_add_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
+
+    GET_IR_NODE_FROM_SUBGRAPH(shape0, shape0, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(slice0, slice0, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(shape1, shape1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(slice1, slice1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fill_constant, fill_constant, pattern);
+
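+    // Rewire the matched subgraph so the arithmetic stays in int32:
+    //   1. slice0 reads shape1's output, leaving shape0/shape0_out dead;
+    //   2. cast1 (int64->int32) is hoisted in front of the add, so it casts
+    //      elementwise_add_in0 while the add consumes cast1_out and the raw
+    //      slice1_out, and fill_constant takes scale_out directly;
+    //   3. cast0/cast0_out are bypassed and removed with the spare shape.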
+    slice0->Op()->RenameInput(shape0_out->Name(), shape1_out->Name());
+    IR_NODE_UNLINK(shape0_out, slice0);
+    IR_NODE_LINK_TO(shape1_out, slice0);
+
+    cast1->Op()->RenameInput(scale_out->Name(), elementwise_add_in0->Name());
+    elementwise_add->Op()->RenameInput(elementwise_add_in0->Name(),
+                                       cast1_out->Name());
+    IR_NODE_UNLINK(scale_out, cast1);
+    IR_NODE_UNLINK(elementwise_add_in0, elementwise_add);
+    IR_NODE_LINK_TO(elementwise_add_in0, cast1);
+
+    fill_constant->Op()->RenameInput(cast1_out->Name(), scale_out->Name());
+    IR_NODE_UNLINK(cast1_out, fill_constant);
+    IR_NODE_LINK_TO(cast1_out, elementwise_add);
+    IR_NODE_LINK_TO(scale_out, fill_constant);
+
+    elementwise_add->Op()->RenameInput(cast0_out->Name(), slice1_out->Name());
+    IR_NODE_UNLINK(slice1_out, cast0);
+    IR_NODE_UNLINK(cast0_out, elementwise_add);
+    IR_NODE_LINK_TO(slice1_out, elementwise_add);
+
+    std::unordered_set<const Node*> delete_nodes{
+        shape0, shape0_out, cast0, cast0_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
 void XpuDeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
@@ -242,6 +416,16 @@ void XpuDeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
     LOG(INFO) << "--- delete " << found_subgraph_count
               << " cast_layer_norm_cast subgraph";
   }
+
+  found_subgraph_count = 0;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    found_subgraph_count +=
+        ApplyCastCacheKVInitializationPass(graph->GetSubGraph(i));
+  }
+  if (found_subgraph_count > 0) {
+    LOG(INFO) << "--- delete " << found_subgraph_count
+              << " cast_cachekv_initialization_pattern subgraph";
+  }
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h
index d0556e8b0bf..06309dec31f 100755
--- a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h
+++ b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h
@@ -62,6 +62,57 @@ class XpuDeleteCastOpPass : public FusePassBase {
    */
   int ApplyCastLayerNormPass(ir::Graph* graph) const;
 
+  /*
+  ------------------------------------------------------
+  sub block:
+                  x
+                /   \
+               /     \
+              /       \
+          shape       shape
+            |           |
+            |         slice
+          slice         |
+            |   (max_dec_len)  cast
+            |             \     |
+            |          elementwise_add
+            |                 |
+            |               scale
+            |                 |
+            |               cast
+            |                 |
+             \               /
+              \             /
+               \           /
+                   fill
+
+  ------------------------------------------------------
+  After the pass is applied:
+                  x
+                  |
+                  |
+                shape
+                /   \
+               /     \
+              /       \
+          slice       slice
+            |  (max_dec_len)  |
+            |        \        |
+            |        cast     |
+            |          \      |
+            |       elementwise_add
+            |                 |
+            |                 |
+            |               scale
+            |                 |
+             \               /
+              \             /
+               \           /
+                   fill
+  */
+  int ApplyCastCacheKVInitializationPass(ir::Graph* graph) const;
+
   const std::string name_scope_{"xpu_delete_cast_op_pass"};
 };
 
diff --git a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc
index f8682bdca9f..9ba89d0af28 100644
--- a/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc
+++ b/paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass_test.cc
@@ -112,6 +112,96 @@ TEST(ApplyCastLayerNormPass, basic) {
                         cast_num_in_graph));
 }
 
+TEST(ApplyCastCacheKVInitializationPass, basic) {
+  paddle::framework::ProgramDesc program;
+  auto* block = program.MutableBlock(0);
+  auto* shape_in =
+      Data(block, "shape_in", {64, 128}, false, proto::VarType::INT64);
+  auto* shape0_out =
+      Data(block, "shape0_out", {2}, false, proto::VarType::INT32);
+  auto* shape1_out =
+      Data(block, "shape1_out", {2}, false, proto::VarType::INT32);
+  auto* slice0_out =
+      Data(block, "slice0_out", {1}, false, proto::VarType::INT32);
+  auto* slice1_out =
+      Data(block, "slice1_out", {1}, false, proto::VarType::INT32);
+  auto* elementwise_add_in0 =
+      Data(block, "elementwise_add_in0", {1}, false, proto::VarType::INT64);
+  auto* elementwise_add_out =
+      Data(block, "elementwise_add_out", {1}, false, proto::VarType::INT64);
+  auto* scale_out = Data(block, "scale_out", {1}, false, proto::VarType::INT64);
+
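+  // Assemble the "sub block" graph from the diagram in
+  // xpu_delete_cast_op_pass.h: two shape ops over the same input, a slice
+  // on each branch, and the int32->int64->int32 cast chain feeding
+  // fill_constant.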
+  OpDesc* shape0 = block->AppendOp();
+  shape0->SetType("shape");
+  shape0->SetInput("X", {shape_in->Name()});
+  shape0->SetOutput("Out", {shape0_out->Name()});
+
+  OpDesc* shape1 = block->AppendOp();
+  shape1->SetType("shape");
+  shape1->SetInput("X", {shape_in->Name()});
+  shape1->SetOutput("Out", {shape1_out->Name()});
+
+  OpDesc* slice0 = block->AppendOp();
+  slice0->SetType("slice");
+  slice0->SetInput("X", {shape0_out->Name()});
+  slice0->SetOutput("Out", {slice0_out->Name()});
+
+  OpDesc* slice1 = block->AppendOp();
+  slice1->SetType("slice");
+  slice1->SetInput("X", {shape1_out->Name()});
+  slice1->SetOutput("Out", {slice1_out->Name()});
+
+  auto* cast0_out = AddCast(block,
+                            slice1_out,
+                            static_cast<int>(proto::VarType::INT32),
+                            static_cast<int>(proto::VarType::INT64));
+
+  OpDesc* elementwise_add = block->AppendOp();
+  elementwise_add->SetType("elementwise_add");
+  elementwise_add->SetInput("X", {elementwise_add_in0->Name()});
+  elementwise_add->SetInput("Y", {cast0_out->Name()});
+  elementwise_add->SetOutput("Out", {elementwise_add_out->Name()});
+
+  OpDesc* scale = block->AppendOp();
+  scale->SetType("scale");
+  scale->SetInput("X", {elementwise_add_out->Name()});
+  scale->SetOutput("Out", {scale_out->Name()});
+  scale->SetAttr("scale", 1.0f);
+  scale->SetAttr("bias", 64.0f);
+
+  auto* cast1_out = AddCast(block,
+                            scale_out,
+                            static_cast<int>(proto::VarType::INT64),
+                            static_cast<int>(proto::VarType::INT32));
+
+  OpDesc* fill_constant = block->AppendOp();
+  fill_constant->SetType("fill_constant");
+  fill_constant->SetInput("X", {slice0_out->Name()});
+  fill_constant->SetInput("Y", {cast1_out->Name()});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
+  auto* scope = new Scope();
+  graph->Set("__param_scope__", scope);
+  auto pass = PassRegistry::Instance().Get("xpu_delete_cast_op_pass");
+  pass->Apply(graph.get());
+  int shape_num_in_graph = GetOpNum(graph->GetSubGraph(0), "shape");
+  PADDLE_ENFORCE_EQ(
+      GetOpNum(graph->GetSubGraph(0), "shape"),
+      1,
+      platform::errors::PreconditionNotMet("graph should have 1 shape after "
+                                           "xpu_delete_cast_op_pass, "
+                                           "but actually has %d.",
+                                           shape_num_in_graph));
+  int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
+  PADDLE_ENFORCE_EQ(
+      GetOpNum(graph->GetSubGraph(0), "cast"),
+      1,
+      platform::errors::PreconditionNotMet("graph should have 1 cast after "
+                                           "xpu_delete_cast_op_pass, "
+                                           "but actually has %d.",
+                                           cast_num_in_graph));
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
--
GitLab