Unverified · Commit 720b14e3 authored by zhupengyang, committed by GitHub

[XPU] optimize graph if beam_size=1 (#51732)

Parent 2922aa67
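This commit adds a new XPU IR pass, one_beam_size_fuse_pass, registers it in the subgraph pass list and the XPU inference pass strategy, and folds away beam-search bookkeeping ops when beam_size == 1. For orientation only, the sketch below shows an inference setup under which the new pass would run; it is not part of the commit, the model file names are placeholders, and it assumes the standard paddle_infer C++ API with an install-dependent include path.

// Hypothetical usage sketch: with the XPU backend and IR optimization enabled,
// XpuPassStrategy (see paddle_pass_builder.cc at the end of this diff) applies
// one_beam_size_fuse_pass automatically; no per-pass configuration is needed.
#include "paddle_inference_api.h"  // include path depends on the inference library layout

int main() {
  paddle_infer::Config config;
  // Placeholder model files for a beam-search decoding model exported with beam_size = 1.
  config.SetModel("model.pdmodel", "model.pdiparams");
  config.EnableXpu();          // select the XPU backend and XpuPassStrategy
  config.SwitchIrOptim(true);  // run IR passes, including the new fuse pass
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}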
...@@ -233,6 +233,7 @@ if(WITH_XPU)
pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(one_beam_size_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
...@@ -499,4 +500,8 @@ if(WITH_XPU)
test_fused_multi_transformer_xpu_quant_pass
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
DEPS fused_multi_transformer_xpu_quant_pass)
cc_test(
test_one_beam_size_fuse_pass
SRCS xpu/one_beam_size_fuse_pass_test.cc
DEPS one_beam_size_fuse_pass)
endif()
...@@ -49,9 +49,11 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"one_beam_size_fuse_pass",
"fused_multi_transformer_xpu_quant_pass", "fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass", "fc_xpu_fuse_pass",
"delete_op_device_pass"}; "delete_op_device_pass",
};
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
...
...@@ -33,6 +33,8 @@ struct Layers {
public:
const ProgramDesc& main_program() { return program_; }
BlockDesc* Block() { return program_.MutableBlock(0); }
VarDesc* data(std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
...@@ -132,7 +134,7 @@ struct Layers {
return out;
}
VarDesc* unsqueeze2(VarDesc* x, const std::vector<int> axes = {-1}) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("unsqueeze2");
...@@ -294,6 +296,13 @@ struct Layers {
return binary_op("elementwise_mul", x, y, out, attrs);
}
VarDesc* elementwise_div(VarDesc* x,
VarDesc* y,
VarDesc* out = nullptr,
const AttributeMap* attrs = nullptr) {
return binary_op("elementwise_div", x, y, out, attrs);
}
VarDesc* dropout(VarDesc* x,
float dropout_prob,
std::string dropout_implementation) {
...@@ -458,7 +467,10 @@ struct Layers {
return out;
}
VarDesc* scale(VarDesc* x,
float scale = 1.,
float bias = 0.,
bool bias_after = true) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("scale");
...@@ -713,6 +725,88 @@ struct Layers {
}
}
VarDesc* cast(VarDesc* input, int in_dtype = 5, int out_dtype = 5) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("cast");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("in_dtype", in_dtype);
op->SetAttr("out_dtype", out_dtype);
return out;
}
VarDesc* range(VarDesc* start, VarDesc* end, VarDesc* step) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("range");
op->SetInput("Start", {start->Name()});
op->SetInput("End", {end->Name()});
op->SetInput("Step", {step->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* flatten_contiguous_range(VarDesc* input) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("flatten_contiguous_range");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
std::vector<VarDesc*> beam_search(VarDesc* ids,
VarDesc* scores,
VarDesc* pre_ids,
VarDesc* pre_scores,
int beam_size = 1) {
VarDesc* parent_idx = lod_tensor(unique_name());
VarDesc* selected_ids = lod_tensor(unique_name());
VarDesc* selected_scores = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("beam_search");
op->SetInput("ids", {ids->Name()});
op->SetInput("scores", {scores->Name()});
op->SetInput("pre_ids", {pre_ids->Name()});
op->SetInput("pre_scores", {pre_scores->Name()});
op->SetOutput("parent_idx", {parent_idx->Name()});
op->SetOutput("selected_ids", {selected_ids->Name()});
op->SetOutput("selected_scores", {selected_scores->Name()});
op->SetAttr("beam_size", 1);
return {parent_idx, selected_ids, selected_scores};
}
VarDesc* lod_reset(VarDesc* x, VarDesc* y) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("lod_reset");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* write_to_array(std::vector<VarDesc*> x, VarDesc* i) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("write_to_array");
std::vector<std::string> x_names;
for (auto k : x) {
x_names.push_back(k->Name());
}
op->SetInput("X", x_names);
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* is_empty(VarDesc* input) { return unary_op("is_empty", input); }
VarDesc* logical_not(VarDesc* input) {
return unary_op("logical_not", input);
}
private:
VarDesc* lod_tensor(std::string name,
std::vector<int64_t> shape = {},
...@@ -927,10 +1021,11 @@ static std::vector<ir::Node*> GetOpNodes(const std::unique_ptr<Graph>& graph,
}
static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
std::string op_type = "") {
int num_nodes = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
(node->Op()->Type() == op_type || op_type.empty())) {
num_nodes++;
}
}
...
...@@ -39,7 +39,6 @@ namespace patterns {
struct FusedMultiTransformerPattern : public PatternBase {
FusedMultiTransformerPattern(PDPattern* pattern,
const std::string& name_scope,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
...@@ -54,7 +53,6 @@ struct FusedMultiTransformerPattern : public PatternBase {
PATTERN_DECL_NODE(ln_bias);
PATTERN_DECL_NODE(qkv_w);
PATTERN_DECL_NODE(qkv_bias);
PATTERN_DECL_NODE(cache_kv);
PATTERN_DECL_NODE(pre_caches);
PATTERN_DECL_NODE(rotary_pos_emb);
PATTERN_DECL_NODE(time_step);
...@@ -68,11 +66,9 @@ struct FusedMultiTransformerPattern : public PatternBase {
PATTERN_DECL_NODE(ffn1_bias);
PATTERN_DECL_NODE(ffn2_w);
PATTERN_DECL_NODE(ffn2_bias);
PATTERN_DECL_NODE(cache_kv_out);
PATTERN_DECL_NODE(out);
private:
bool with_cache_kv_{false};
bool with_pre_caches_{false};
bool with_rotary_pos_emb_{false};
bool with_time_step_{false};
...@@ -83,14 +79,12 @@ struct FusedMultiTransformerPattern : public PatternBase {
FusedMultiTransformerPattern::FusedMultiTransformerPattern(
PDPattern* pattern,
const std::string& name_scope,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
bool with_seq_lengths,
bool with_src_mask)
: PatternBase(pattern, name_scope, name_scope),
with_cache_kv_(with_cache_kv),
with_pre_caches_(with_pre_caches),
with_rotary_pos_emb_(with_rotary_pos_emb),
with_time_step_(with_time_step),
...@@ -102,9 +96,6 @@ FusedMultiTransformerPattern::FusedMultiTransformerPattern(
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input(op_type, "X")
->assert_var_not_persistable();
auto* cache_kv_out = pattern->NewNode(cache_kv_out_repr())
->assert_is_op_output(op_type, "CacheKVOut")
->assert_var_not_persistable();
auto* out = pattern->NewNode(out_repr())
->assert_is_op_output(op_type, "Out")
->assert_var_not_persistable();
...@@ -195,21 +186,14 @@ FusedMultiTransformerPattern::FusedMultiTransformerPattern(
ffn1_bias,
ffn2_w,
ffn2_bias};
std::vector<PDNode*> output_vars{out};
// optional node
PDNode* cache_kv = nullptr;
PDNode* pre_caches = nullptr;
PDNode* rotary_pos_emb = nullptr;
PDNode* time_step = nullptr;
PDNode* seq_lengths = nullptr;
PDNode* src_mask = nullptr;
if (with_cache_kv_) {
cache_kv = pattern->NewNode(cache_kv_repr())
->assert_is_op_input(op_type, "CacheKV")
->assert_var_not_persistable();
input_vars.push_back(cache_kv);
}
if (with_pre_caches_) {
pre_caches = pattern->NewNode(pre_caches_repr())
->assert_is_op_input(op_type, "PreCaches")
...@@ -256,7 +240,6 @@ class FusedMultiTransformerXPUQuantPass : public FusePassBase {
private:
int ApplyImpl(ir::Graph* graph,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
...@@ -275,13 +258,12 @@ void FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph) const {
int found_subgraph_count = 0;
for (bool with_time_step : {true, false}) {
found_subgraph_count +=
ApplyImpl(graph, false, false, with_time_step, false, true);
}
AddStatis(found_subgraph_count);
}
int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
bool with_cache_kv,
bool with_pre_caches,
bool with_rotary_pos_emb,
bool with_time_step,
...@@ -290,7 +272,6 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
GraphPatternDetector gpd;
patterns::FusedMultiTransformerPattern pattern(gpd.mutable_pattern(),
name_scope_,
with_cache_kv,
with_pre_caches,
with_rotary_pos_emb,
with_time_step,
...@@ -307,7 +288,6 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
GET_IR_NODE(ln_bias);
GET_IR_NODE(qkv_w);
GET_IR_NODE(qkv_bias);
GET_IR_NODE(cache_kv);
GET_IR_NODE(pre_caches);
GET_IR_NODE(rotary_pos_emb);
GET_IR_NODE(time_step);
...@@ -321,7 +301,6 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
GET_IR_NODE(ffn1_bias);
GET_IR_NODE(ffn2_w);
GET_IR_NODE(ffn2_bias);
GET_IR_NODE(cache_kv_out);
GET_IR_NODE(out);
GET_IR_NODE(fused_mt);
auto* block = fused_mt->Op()->Block();
...@@ -469,7 +448,7 @@ int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
fused_mt_xpu_op_desc->SetInput("ln_scale", name_caches.at("LnScale"));
fused_mt_xpu_op_desc->SetInput("ln_bias", name_caches.at("LnBias"));
fused_mt_xpu_op_desc->SetInput("qkv_bias", name_caches.at("QKVBias"));
if (name_caches.count("CacheKV") > 0) {
fused_mt_xpu_op_desc->SetInput("cache_kv", name_caches.at("CacheKV"));
}
if (pre_caches) {
...
...@@ -12,17 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h" #include "paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
......
...@@ -12,17 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct AssignPattern : public PatternBase {
AssignPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(assign);
// declare variable node's name
PATTERN_DECL_NODE(assign_out);
};
AssignPattern::AssignPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* assign =
pattern->NewNode(assign_repr())
->assert_is_op("assign")
->assert_more([&](Node* node) {
auto pre_op_nodes = node->inputs[0]->inputs;
return pre_op_nodes.size() == 1 &&
pre_op_nodes[0]->Op()->Type() == "fused_multi_transformer";
});
auto* assign_out =
pattern->NewNode(assign_out_repr())->assert_is_op_output("assign", "Out");
assign->LinksTo({assign_out});
}
struct ShapeAssociatedOpsPattern : public PatternBase {
ShapeAssociatedOpsPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(shape);
PATTERN_DECL_NODE(slice);
PATTERN_DECL_NODE(div);
PATTERN_DECL_NODE(cast_0);
PATTERN_DECL_NODE(cast_1);
PATTERN_DECL_NODE(scale_0);
PATTERN_DECL_NODE(cast_2);
PATTERN_DECL_NODE(range);
PATTERN_DECL_NODE(unsqueeze2);
PATTERN_DECL_NODE(scale_1);
PATTERN_DECL_NODE(add);
PATTERN_DECL_NODE(flatten_contiguous_range);
// declare variable node's name
PATTERN_DECL_NODE(shape_out);
PATTERN_DECL_NODE(slice_out);
PATTERN_DECL_NODE(div_out);
PATTERN_DECL_NODE(cast_0_out);
PATTERN_DECL_NODE(cast_1_out);
PATTERN_DECL_NODE(scale_0_out);
PATTERN_DECL_NODE(cast_2_out);
PATTERN_DECL_NODE(range_out);
PATTERN_DECL_NODE(unsqueeze2_out);
PATTERN_DECL_NODE(scale_1_out);
PATTERN_DECL_NODE(add_x);
PATTERN_DECL_NODE(add_out);
};
ShapeAssociatedOpsPattern::ShapeAssociatedOpsPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* shape = pattern->NewNode(shape_repr())->assert_is_op("shape");
auto* shape_out = pattern->NewNode(shape_out_repr())
->assert_is_op_output("shape", "Out")
->assert_is_op_input("slice", "Input");
auto* slice =
pattern->NewNode(slice_repr())
->assert_is_op("slice")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
std::vector<int>{0} &&
op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
std::vector<int>{0} &&
op_desc->GetAttrIfExists<std::vector<int>>("ends") ==
std::vector<int>{1};
});
auto* slice_out = pattern->NewNode(slice_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("elementwise_div", "X")
->assert_is_op_input("elementwise_div", "Y")
->assert_is_op_input("cast", "X")
->assert_is_op_input("scale", "X");
auto* div = pattern->NewNode(div_repr())->assert_is_op("elementwise_div");
auto* div_out = pattern->NewNode(div_out_repr())
->assert_is_op_output("elementwise_div", "Out")
->assert_is_op_input("cast", "X");
auto* cast_0 = pattern->NewNode(cast_0_repr())->assert_is_op("cast");
auto* cast_0_out = pattern->NewNode(cast_0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("range", "Step");
auto* cast_1 = pattern->NewNode(cast_1_repr())->assert_is_op("cast");
auto* cast_1_out = pattern->NewNode(cast_1_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("range", "End");
auto* scale_0 = pattern->NewNode(scale_0_repr())->assert_is_op("scale");
auto* scale_0_out = pattern->NewNode(scale_0_out_repr())
->assert_is_op_output("scale", "Out")
->assert_is_op_input("cast", "X");
auto* cast_2 = pattern->NewNode(cast_2_repr())->assert_is_op("cast");
auto* cast_2_out = pattern->NewNode(cast_2_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("range", "Start");
auto* range = pattern->NewNode(range_repr())->assert_is_op("range");
auto* range_out = pattern->NewNode(range_out_repr())
->assert_is_op_output("range", "Out")
->assert_is_op_input("unsqueeze2", "X");
auto* unsqueeze2 =
pattern->NewNode(unsqueeze2_repr())->assert_is_op("unsqueeze2");
auto* unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->assert_is_op_output("unsqueeze2", "Out")
->assert_is_op_input("scale", "X");
auto* scale_1 = pattern->NewNode(scale_1_repr())->assert_is_op("scale");
auto* scale_1_out = pattern->NewNode(scale_1_out_repr())
->assert_is_op_output("scale", "Out")
->assert_is_op_input("elementwise_add", "Y");
auto* add_x = pattern->NewNode(add_x_repr())
->assert_is_op_input("elementwise_add", "X");
auto* add = pattern->NewNode(add_repr())->assert_is_op("elementwise_add");
auto* add_out = pattern->NewNode(add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("flatten_contiguous_range", "X");
auto* flatten_contiguous_range =
pattern->NewNode(flatten_contiguous_range_repr())
->assert_is_op("flatten_contiguous_range");
shape->LinksTo({shape_out});
slice->LinksFrom({shape_out}).LinksTo({slice_out});
div->LinksFrom({slice_out}).LinksTo({div_out});
cast_0->LinksFrom({div_out}).LinksTo({cast_0_out});
cast_1->LinksFrom({slice_out}).LinksTo({cast_1_out});
scale_0->LinksFrom({slice_out}).LinksTo({scale_0_out});
cast_2->LinksFrom({scale_0_out}).LinksTo({cast_2_out});
range->LinksFrom({cast_0_out, cast_1_out, cast_2_out}).LinksTo({range_out});
unsqueeze2->LinksFrom({range_out}).LinksTo({unsqueeze2_out});
scale_1->LinksFrom({unsqueeze2_out}).LinksTo({scale_1_out});
add->LinksFrom({scale_1_out, add_x}).LinksTo({add_out});
flatten_contiguous_range->LinksFrom({add_out});
}
struct BeamSearchAssociatedOpsPattern : public PatternBase {
BeamSearchAssociatedOpsPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(lod_reset_0);
PATTERN_DECL_NODE(lod_reset_1);
PATTERN_DECL_NODE(beam_search);
PATTERN_DECL_NODE(write_to_array_0);
PATTERN_DECL_NODE(write_to_array_1);
PATTERN_DECL_NODE(is_empty);
PATTERN_DECL_NODE(logical_not);
PATTERN_DECL_NODE(cast);
// declare variable node's name
PATTERN_DECL_NODE(lod_reset_0_out);
PATTERN_DECL_NODE(lod_reset_1_out);
PATTERN_DECL_NODE(beam_search_parent_idx);
PATTERN_DECL_NODE(beam_search_selected_ids);
PATTERN_DECL_NODE(beam_search_selected_scores);
PATTERN_DECL_NODE(is_empty_out);
PATTERN_DECL_NODE(logical_not_out);
PATTERN_DECL_NODE(cast_out);
};
BeamSearchAssociatedOpsPattern::BeamSearchAssociatedOpsPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* lod_reset_0 =
pattern->NewNode(lod_reset_0_repr())->assert_is_op("lod_reset");
auto* lod_reset_0_out = pattern->NewNode(lod_reset_0_out_repr())
->assert_is_op_output("lod_reset", "Out")
->assert_is_op_input("beam_search", "ids");
auto* lod_reset_1 =
pattern->NewNode(lod_reset_1_repr())->assert_is_op("lod_reset");
auto* lod_reset_1_out = pattern->NewNode(lod_reset_1_out_repr())
->assert_is_op_output("lod_reset", "Out")
->assert_is_op_input("beam_search", "scores");
auto* beam_search =
pattern->NewNode(beam_search_repr())->assert_is_op("beam_search");
auto* beam_search_selected_ids =
pattern->NewNode(beam_search_selected_ids_repr())
->assert_is_op_output("beam_search", "selected_ids")
->assert_is_op_input("write_to_array", "X")
->assert_is_op_input("is_empty", "X");
auto* beam_search_selected_scores =
pattern->NewNode(beam_search_selected_scores_repr())
->assert_is_op_output("beam_search", "selected_scores")
->assert_is_op_input("write_to_array", "X");
auto* beam_search_parent_idx =
pattern->NewNode(beam_search_parent_idx_repr())
->assert_is_op_output("beam_search", "parent_idx")
->assert_is_op_input("cast", "X");
auto* write_to_array_0 =
pattern->NewNode(write_to_array_0_repr())->assert_is_op("write_to_array");
auto* write_to_array_1 =
pattern->NewNode(write_to_array_1_repr())->assert_is_op("write_to_array");
auto* is_empty = pattern->NewNode(is_empty_repr())->assert_is_op("is_empty");
auto* is_empty_out = pattern->NewNode(is_empty_out_repr())
->assert_is_op_output("is_empty", "Out")
->assert_is_op_input("logical_not", "X");
auto* logical_not =
pattern->NewNode(logical_not_repr())->assert_is_op("logical_not");
auto* logical_not_out = pattern->NewNode(logical_not_out_repr())
->assert_is_op_output("logical_not", "Out");
auto* cast = pattern->NewNode(cast_repr())->assert_is_op("cast");
auto* cast_out =
pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out");
lod_reset_0->LinksTo({lod_reset_0_out});
lod_reset_1->LinksTo({lod_reset_1_out});
beam_search->LinksFrom({lod_reset_0_out, lod_reset_1_out})
.LinksTo({beam_search_selected_ids,
beam_search_selected_scores,
beam_search_parent_idx});
write_to_array_0->LinksFrom({beam_search_selected_ids});
write_to_array_1->LinksFrom({beam_search_selected_scores});
is_empty->LinksFrom({beam_search_selected_ids}).LinksTo({is_empty_out});
logical_not->LinksFrom({is_empty_out}).LinksTo({logical_not_out});
cast->LinksFrom({beam_search_parent_idx}).LinksTo({cast_out});
}
} // namespace patterns
bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) {
std::vector<Node*> beam_search_nodes;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "beam_search") {
beam_search_nodes.push_back(node);
}
}
return beam_search_nodes.size() == 1 &&
beam_search_nodes[0]->Op()->GetAttrIfExists<int>("beam_size") == 1;
}
Node* FindOpNodeByInputName(Graph* graph,
const std::string& op_type,
const std::string& arg_name,
const std::string& var_name) {
for (auto* node : graph->Nodes()) {
if (!node->IsOp() || node->Op()->Type() != op_type) continue;
auto inputs = node->Op()->Inputs();
if (inputs.count(arg_name) == 0) continue;
auto in_names = inputs.at(arg_name);
if (std::find(in_names.begin(), in_names.end(), var_name) == in_names.end())
continue;
return node;
}
return nullptr;
}
void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
// detect assign + gather
GraphPatternDetector gpd;
patterns::AssignPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle RemoveAssignGather";
GET_IR_NODE(assign);
GET_IR_NODE(assign_out);
// Assign_out may not link to gather, so we find gather by input name.
auto* gather =
FindOpNodeByInputName(graph, "gather", "X", assign_out->Name());
if (gather == nullptr) return;
// "assign_out" is used in multi blocks. "assign_out" should be reserved.
auto* assign_in = assign->inputs[0];
auto* fused_multi_transformer = assign_in->inputs[0];
fused_multi_transformer->Op()->Rename(assign_in->Name(),
assign_out->Name());
IR_NODE_LINK_TO(fused_multi_transformer, assign_out);
std::unordered_set<const Node*> delete_nodes{assign, assign_in, gather};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void OneBeamSizeFusePass::FoldShapeAssociatedOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ShapeAssociatedOpsPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FoldShapeAssociatedOps";
GET_IR_NODE(shape);
GET_IR_NODE(slice);
GET_IR_NODE(div);
GET_IR_NODE(cast_0);
GET_IR_NODE(cast_1);
GET_IR_NODE(scale_0);
GET_IR_NODE(cast_2);
GET_IR_NODE(range);
GET_IR_NODE(unsqueeze2);
GET_IR_NODE(scale_1);
GET_IR_NODE(add);
GET_IR_NODE(flatten_contiguous_range);
GET_IR_NODE(shape_out);
GET_IR_NODE(slice_out);
GET_IR_NODE(div_out);
GET_IR_NODE(cast_0_out);
GET_IR_NODE(cast_1_out);
GET_IR_NODE(scale_0_out);
GET_IR_NODE(cast_2_out);
GET_IR_NODE(range_out);
GET_IR_NODE(unsqueeze2_out);
GET_IR_NODE(scale_1_out);
GET_IR_NODE(add_x);
GET_IR_NODE(add_out);
flatten_contiguous_range->Op()->RenameInput(add_out->Name(), add_x->Name());
std::unordered_set<const Node*> delete_nodes{
shape, slice, div, cast_0, cast_1,
scale_0, cast_2, range, unsqueeze2, scale_1,
add, shape_out, slice_out, div_out, cast_0_out,
cast_1_out, scale_0_out, cast_2_out, range_out, unsqueeze2_out,
scale_1_out, add_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void OneBeamSizeFusePass::RemoveBeamSearchAssociatedOps(
ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::BeamSearchAssociatedOpsPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle RemoveBeamSearchAssociatedOps";
GET_IR_NODE(lod_reset_0);
GET_IR_NODE(lod_reset_1);
GET_IR_NODE(beam_search);
GET_IR_NODE(write_to_array_0);
GET_IR_NODE(write_to_array_1);
GET_IR_NODE(is_empty);
GET_IR_NODE(logical_not);
GET_IR_NODE(cast);
GET_IR_NODE(lod_reset_0_out);
GET_IR_NODE(lod_reset_1_out);
GET_IR_NODE(beam_search_parent_idx);
GET_IR_NODE(beam_search_selected_ids);
GET_IR_NODE(beam_search_selected_scores);
GET_IR_NODE(is_empty_out);
GET_IR_NODE(logical_not_out);
GET_IR_NODE(cast_out);
auto* block = lod_reset_0->Op()->Block();
auto* scope = param_scope();
write_to_array_0->Op()->RenameInput(beam_search_selected_ids->Name(),
lod_reset_0_out->Name());
IR_NODE_LINK_TO(lod_reset_0_out, write_to_array_0);
write_to_array_1->Op()->RenameInput(beam_search_selected_scores->Name(),
lod_reset_1_out->Name());
IR_NODE_LINK_TO(lod_reset_1_out, write_to_array_1);
// Transform is_empty to not_equal
is_empty->RenameOp("not_equal");
auto* not_equal = is_empty;
auto* not_equal_desc = not_equal->Op();
not_equal_desc->RenameInput(beam_search_selected_ids->Name(),
lod_reset_0_out->Name());
not_equal_desc->RenameOutput(is_empty_out->Name(), logical_not_out->Name());
std::string not_equal_y_name = lod_reset_0_out->Name() + "_not_equal_y";
not_equal_desc->SetInput("Y", {not_equal_y_name});
VarDesc not_equal_y_desc(not_equal_y_name);
not_equal_y_desc.SetPersistable(true);
not_equal_y_desc.SetShape({static_cast<int64_t>(1)});
not_equal_y_desc.SetDataType(proto::VarType::Type::VarType_Type_INT64);
auto* not_equal_y = graph->CreateVarNode(&not_equal_y_desc);
auto* block_not_equal_y_desc = block->Var(not_equal_y_name);
block_not_equal_y_desc->SetPersistable(not_equal_y_desc.Persistable());
block_not_equal_y_desc->SetShape(not_equal_y_desc.GetShape());
block_not_equal_y_desc->SetDataType(not_equal_y_desc.GetDataType());
auto* not_equal_y_tensor =
scope->Var(not_equal_y_name)->GetMutable<phi::DenseTensor>();
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
not_equal_y_tensor->Resize({1});
not_equal_y_tensor->set_type(phi::DataType::INT64);
auto* not_equal_y_data = cpu_ctx->Alloc<int64_t>(not_equal_y_tensor);
not_equal_y_data[0] = beam_search->Op()->GetAttrIfExists<int>("end_id");
IR_NODE_LINK_TO(not_equal_y, not_equal);
// cast_out is 0
cast_out->Var()->SetPersistable(true);
auto* cast_out_tensor =
scope->Var(cast_out->Name())->GetMutable<phi::DenseTensor>();
cast_out_tensor->Resize({1});
cast_out_tensor->set_type(phi::DataType::INT64);
auto* cast_out_data = cpu_ctx->Alloc<int64_t>(cast_out_tensor);
cast_out_data[0] = 0;
std::unordered_set<const Node*> delete_nodes{
beam_search,
logical_not,
cast,
beam_search_parent_idx,
beam_search_selected_ids,
beam_search_selected_scores,
is_empty_out,
};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
if (!OnlyOneBeamSearchAndOneBeamSize(graph)) return;
RemoveAssignGather(graph);
FoldShapeAssociatedOps(graph);
RemoveBeamSearchAssociatedOps(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(one_beam_size_fuse_pass,
paddle::framework::ir::OneBeamSizeFusePass);
REGISTER_PASS_CAPABILITY(one_beam_size_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"beam_search", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class OneBeamSizeFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
/*
Origin subgraph:
fused_multi_transformer
| | |
assign assign ...
| | |
gather gather ...
Fused subgraph:
fused_multi_transformer
*/
void RemoveAssignGather(ir::Graph* graph) const;
/*
Origin subgraph:
shape
/ | \
/ | \
elementwise_div | scale
| | |
cast cast cast
\ | /
range
|
unsqueeze2
|
scale (add_x)
| /
elementwise_add
|
flatten_contiguous_range
Fused subgraph:
(add_x)
|
flatten_contiguous_range
*/
void FoldShapeAssociatedOps(ir::Graph* graph) const;
/*
Origin subgraph:
lod_reset lod_reset
| |
(ids) (scores)
\ |
beam_search
/ | \
/ | \
/ | \
(selected_ids) (selected_scores) (parent_idx)
/ | | |
write_to_array is_empty write_to_array cast
| |
| (cast_out)
| |
logical_not write_to_array
Fused subgraph:
lod_reset lod_reset (cast_out: fill 0)
| | |
(ids) (scores) write_to_array
/ \ |
write_to_array not_equal write_to_array
*/
void RemoveBeamSearchAssociatedOps(ir::Graph* graph) const;
const std::string name_scope_{"one_beam_size_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
TEST(RemoveAssignGather, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
OpDesc* beam_search_op = block->AppendOp();
beam_search_op->SetType("beam_search");
beam_search_op->SetAttr("beam_size", 1);
auto* x = Data(block, "fused_multi_transformer_x", {1, 1, 1536});
auto* cache_kv =
Data(block, "fused_multi_transformer_cache_kv", {2, 1, 24, 512, 64});
OpDesc* fused_multi_transformer_op = block->AppendOp();
fused_multi_transformer_op->SetType("fused_multi_transformer");
fused_multi_transformer_op->SetInput("X", {x->Name()});
fused_multi_transformer_op->SetInput("CacheKV", {cache_kv->Name()});
fused_multi_transformer_op->SetOutput("CacheKVOut", {cache_kv->Name()});
auto* assign_out = Data(block, "assign_out", cache_kv->GetShape());
OpDesc* assign_op = block->AppendOp();
assign_op->SetType("assign");
assign_op->SetInput("X", {cache_kv->Name()});
assign_op->SetOutput("Out", {assign_out->Name()});
OpDesc* gather_op = block->AppendOp();
gather_op->SetType("gather");
gather_op->SetInput("X", {assign_out->Name()});
gather_op->SetOutput("Out", {cache_kv->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto assign_num = GetNumOpNodes(graph, "assign");
auto gather_num = GetNumOpNodes(graph, "gather");
PADDLE_ENFORCE_EQ(assign_num,
0,
platform::errors::PreconditionNotMet(
"assign op should be removed from the graph."));
PADDLE_ENFORCE_EQ(gather_num,
0,
platform::errors::PreconditionNotMet(
"gather op should be removed from the graph."));
}
TEST(FoldShapeAssociatedOps, basic) {
Layers layers;
auto* block = layers.Block();
OpDesc* beam_search_op = block->AppendOp();
beam_search_op->SetType("beam_search");
beam_search_op->SetAttr("beam_size", 1);
auto* shape_x = layers.data("shape_x", {1, 46256});
auto* shape_out = layers.shape(shape_x);
auto* slice_out = layers.slice(shape_out, {0}, {0}, {1});
auto* div_out = layers.elementwise_div(slice_out, slice_out);
auto* cast0_out = layers.cast(div_out);
auto* cast1_out = layers.cast(slice_out);
auto* scale0_out = layers.scale(slice_out);
auto* cast2_out = layers.cast(scale0_out);
auto* range_out = layers.range(cast2_out, cast1_out, cast0_out);
auto* unsqueeze2_out = layers.unsqueeze2(range_out);
auto* scale1_out = layers.scale(unsqueeze2_out);
auto* add_x = layers.data("add_x", {1, 2});
auto* add_out = layers.elementwise_add(add_x, scale1_out);
layers.flatten_contiguous_range(add_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto ops_num = GetNumOpNodes(graph);
PADDLE_ENFORCE_EQ(
ops_num,
2,
platform::errors::PreconditionNotMet(
"graph should only have 2 op nodes, but received %d.", ops_num));
}
TEST(RemoveBeamSearchAssociatedOps, basic) {
Layers layers;
auto* lod_reset_0_x = layers.data("lod_reset_0_x");
auto* lod_reset_0_y = layers.data("lod_reset_0_y");
auto* lod_reset_0_out = layers.lod_reset(lod_reset_0_x, lod_reset_0_y);
auto* lod_reset_1_x = layers.data("lod_reset_1_x");
auto* lod_reset_1_y = layers.data("lod_reset_1_y");
auto* lod_reset_1_out = layers.lod_reset(lod_reset_1_x, lod_reset_1_y);
auto* pre_ids = layers.data("pre_ids");
auto* pre_scores = layers.data("pre_scores");
auto beam_search_outs =
layers.beam_search(lod_reset_0_out, lod_reset_1_out, pre_ids, pre_scores);
auto* parent_idx = beam_search_outs[0];
auto* selected_ids = beam_search_outs[1];
auto* selected_scores = beam_search_outs[2];
auto* write_to_array_0_i = layers.data("write_to_array_0_i");
layers.write_to_array({selected_ids}, write_to_array_0_i);
auto* write_to_array_1_i = layers.data("write_to_array_1_i");
layers.write_to_array({selected_scores}, write_to_array_1_i);
auto* is_empty_out = layers.is_empty(selected_ids);
layers.logical_not(is_empty_out);
layers.cast(parent_idx);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto* param_scope = new Scope();
graph->Set("__param_scope__", param_scope);
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto beam_search_num = GetNumOpNodes(graph, "beam_search");
PADDLE_ENFORCE_EQ(beam_search_num,
0,
platform::errors::PreconditionNotMet(
"beam_search op should be removed from the graph."));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(one_beam_size_fuse_pass);
...@@ -105,8 +105,7 @@ size_t HashTensor(const phi::DenseTensor& in) {
template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
template size_t HashTensor<float>(const phi::DenseTensor& in);
std::string GetPrefixWithoutHash(const std::string& name) {
std::size_t found = name.find("_#");
return found == std::string::npos ? name : name.substr(0, found);
}
...@@ -128,7 +127,7 @@ void PrepareWeight(Graph* graph,
size_t dst_hash = HashTensor<T>(dst_tensor);
size_t dst_max_hash = HashTensor<float>(dst_max_tensor);
std::string pre_name = GetPrefixWithoutHash(src_name);
std::string dst_name = pre_name + "_#" + std::to_string(dst_hash);
std::string dst_max_name = pre_name + "_max_#" + std::to_string(dst_max_hash);
*dst = FindNodeWithName(graph, dst_name);
...@@ -206,7 +205,7 @@ void PrepareBias(
phi::DenseTensor dst_tensor;
CastToFp32(src_tensor, &dst_tensor);
size_t dst_hash = HashTensor<float>(dst_tensor);
std::string pre_name = GetPrefixWithoutHash(src_name);
std::string dst_name = pre_name + "_#" + std::to_string(dst_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
...
...@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"one_beam_size_fuse_pass",
"fused_multi_transformer_xpu_quant_pass", "fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass", "fc_xpu_fuse_pass",
"link_xpu_op_max_pass", "link_xpu_op_max_pass",
......