未验证 提交 5914b18a 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] support transformer generation: some passes (#42664)

* [Paddle-Inference] support transformer generation: some passes
上级 a7926ef2
......@@ -107,6 +107,9 @@ if(WITH_TENSORRT)
pass_library(trt_map_matmul_to_mul_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_skip_layernorm_fuse_pass inference)
pass_library(set_transformer_input_convert_pass inference)
pass_library(remove_padding_recover_padding_pass inference)
pass_library(delete_remove_padding_recover_padding_pass inference)
endif()
if(WITH_GPU OR WITH_ROCM)
......@@ -161,6 +164,7 @@ if(WITH_IPU)
pass_library(infer_shape_pass base DIR ipu)
pass_library(delete_scale_op_pass base DIR ipu)
pass_library(avg_shard_pass base DIR ipu)
pass_library(transfer_cast_op_pass base DIR ipu)
endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
void RecoverPadding::operator()() {
// Create nodes for recover_padding.
auto *recover_padding_input =
pattern->NewNode(recover_padding_input_repr())
->assert_is_op_input("recover_padding", "Input");
auto *recover_padding_op = pattern->NewNode(recover_padding_op_repr())
->assert_is_op("recover_padding");
auto *recover_padding_out =
pattern->NewNode(recover_padding_out_repr())
->assert_is_op_output("recover_padding", "Out");
// Add links for recover_padding op.
recover_padding_op->LinksFrom({recover_padding_input})
.LinksTo({recover_padding_out});
}
} // namespace patterns
void DeleteRemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph);
int found_subgraph_count = 0;
//
GraphPatternDetector gpd;
patterns::RecoverPadding recover_padding(
gpd.mutable_pattern(), "delete_remove_padding_recover_padding_pass");
recover_padding();
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
VLOG(3) << "delete_remove_padding_recover_padding_pass";
GET_IR_NODE_FROM_SUBGRAPH(recover_padding_input, recover_padding_input,
recover_padding);
GET_IR_NODE_FROM_SUBGRAPH(recover_padding_op, recover_padding_op,
recover_padding);
GET_IR_NODE_FROM_SUBGRAPH(recover_padding_out, recover_padding_out,
recover_padding);
std::unordered_set<const Node *> del_node_set;
bool delete_recover_padding = true;
for (size_t i = 0; i < recover_padding_out->outputs.size(); ++i) {
if (recover_padding_out->outputs[i]->Name() ==
"remove_padding") { // op_node
auto *remove_padding_out_node =
recover_padding_out->outputs[i]->outputs[0]; // var_node
auto *out_op_node = remove_padding_out_node->outputs[0]; // op_node
IR_NODE_LINK_TO(recover_padding_input, out_op_node);
del_node_set.insert(recover_padding_out->outputs[i]);
del_node_set.insert(remove_padding_out_node);
out_op_node->Op()->RenameInput(remove_padding_out_node->Name(),
recover_padding_input->Name());
found_subgraph_count++;
} else {
delete_recover_padding = false;
}
}
if (delete_recover_padding) {
del_node_set.insert(recover_padding_op);
del_node_set.insert(recover_padding_out);
}
GraphSafeRemoveNodes(graph, del_node_set);
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_remove_padding_recover_padding_pass,
paddle::framework::ir::DeleteRemovePaddingRecoverPaddingPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct RecoverPadding : public PatternBase {
RecoverPadding(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "recover_padding") {}
void operator()();
PATTERN_DECL_NODE(recover_padding_input);
PATTERN_DECL_NODE(recover_padding_op);
PATTERN_DECL_NODE(recover_padding_out);
};
} // namespace patterns
class DeleteRemovePaddingRecoverPaddingPass : public FusePassBase {
public:
DeleteRemovePaddingRecoverPaddingPass() {}
virtual ~DeleteRemovePaddingRecoverPaddingPass() {}
protected:
void ApplyImpl(Graph *graph) const;
const std::string name_scope_{"delete_remove_padding_recover_padding_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
void SkipLayernorm::operator()() {
// Create nodes for skip_layernorm.
auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
->assert_is_op_input("skip_layernorm", "X");
auto* skip_layernorm_y = pattern->NewNode(skip_layernorm_y_repr())
->assert_is_op_input("skip_layernorm", "Y");
auto* skip_layernorm_op = pattern->NewNode(skip_layernorm_op_repr())
->assert_is_op("skip_layernorm");
auto* skip_layernorm_out = pattern->NewNode(skip_layernorm_out_repr())
->assert_is_op_output("skip_layernorm", "Out");
// Add links for skip_layernorm op.
skip_layernorm_op->LinksFrom({skip_layernorm_x, skip_layernorm_y})
.LinksTo({skip_layernorm_out});
}
void MultiheadMatmul::operator()() {
// Create nodes for multihead_matmul.
auto* multihead_matmul_input =
pattern->NewNode(multihead_matmul_input_repr())
->assert_is_op_input("multihead_matmul", "Input");
auto* multihead_matmul_op = pattern->NewNode(multihead_matmul_op_repr())
->assert_is_op("multihead_matmul");
auto* multihead_matmul_out =
pattern->NewNode(multihead_matmul_out_repr())
->assert_is_op_output("multihead_matmul", "Out");
// Add links for multihead_matmul op.
multihead_matmul_op->LinksFrom({multihead_matmul_input})
.LinksTo({multihead_matmul_out});
}
void Fc::operator()() {
// Create nodes for fc.
auto* fc_input =
pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input");
auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
auto* fc_out =
pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
// Add links for fc op.
fc_op->LinksFrom({fc_input}).LinksTo({fc_out});
}
void Activation::operator()() {
// Create nodes for activation.
std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "tanh"};
auto* activation_input = pattern->NewNode(activation_input_repr())
->assert_is_ops_input(activation_ops);
auto* activation_op =
pattern->NewNode(activation_op_repr())->assert_is_ops(activation_ops);
auto* activation_out = pattern->NewNode(activation_out_repr())
->assert_is_ops_output(activation_ops);
// Add links for activation op.
activation_op->LinksFrom({activation_input}).LinksTo({activation_out});
}
} // namespace patterns
void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
int found_subgraph_count = 0;
// Create an remove_padding op node
auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) {
// create op, var in graph
OpDesc remove_padding;
std::string remove_padding_out_name =
input_node->Name() + ".remove_padding";
VarDesc remove_padding_out(remove_padding_out_name);
remove_padding_out.SetDataType(input_node->Var()->GetDataType());
remove_padding_out.SetShape(input_node->Var()->GetShape());
remove_padding_out.SetPersistable(false);
// remove_padding_op
remove_padding.SetType("remove_padding");
// input
remove_padding.SetInput("Input", {input_node->Name()});
// output
remove_padding.SetOutput("Out", {remove_padding_out_name});
auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
auto remove_padding_out_node = graph->CreateVarNode(&remove_padding_out);
// replace link
for (size_t i = 0; i < input_node->outputs.size(); ++i) {
if (input_node->outputs[i] == op_node) {
input_node->outputs[i] = remove_padding_op_node;
remove_padding_op_node->inputs.push_back(input_node);
}
}
// link node
IR_NODE_LINK_TO(remove_padding_op_node, remove_padding_out_node);
// replace link
for (size_t i = 0; i < op_node->inputs.size(); ++i) {
if (op_node->inputs[i] == input_node) {
op_node->inputs[i] = remove_padding_out_node;
remove_padding_out_node->outputs.push_back(op_node);
}
}
// create variable in scope
scope->Var(remove_padding_out_name);
auto* remove_padding_out_tensor =
scope->FindVar(remove_padding_out_name)->GetMutable<LoDTensor>();
remove_padding_out_tensor->mutable_data<float>(platform::CUDAPlace());
// rename
op_node->Op()->RenameInput(input_node->Name(),
remove_padding_out_node->Name());
};
// create an remove_padding op node
auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) {
// create op, var in graph
OpDesc recover_padding;
std::string recover_padding_input_name =
out_node->Name() + ".recover_padding";
VarDesc recover_padding_input(recover_padding_input_name);
recover_padding_input.SetDataType(out_node->Var()->GetDataType());
recover_padding_input.SetShape(out_node->Var()->GetShape());
recover_padding_input.SetPersistable(false);
// recover_padding_op
recover_padding.SetType("recover_padding");
// input
recover_padding.SetInput("Input", {recover_padding_input_name});
// output
recover_padding.SetOutput("Out", {out_node->Name()});
auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
auto recover_padding_input_node =
graph->CreateVarNode(&recover_padding_input);
// replace link
for (size_t i = 0; i < op_node->outputs.size(); ++i) {
if (op_node->outputs[i] == out_node) {
op_node->outputs[i] = recover_padding_input_node;
recover_padding_input_node->inputs.push_back(op_node);
}
}
// link node
IR_NODE_LINK_TO(recover_padding_input_node, recover_padding_op_node);
// replace link
for (size_t i = 0; i < out_node->inputs.size(); ++i) {
if (out_node->inputs[i] == op_node) {
out_node->inputs[i] = recover_padding_op_node;
recover_padding_op_node->outputs.push_back(out_node);
}
}
// create variable in scope
scope->Var(recover_padding_input_name);
auto* recover_padding_input_tensor =
scope->FindVar(recover_padding_input_name)->GetMutable<LoDTensor>();
recover_padding_input_tensor->mutable_data<float>(platform::CUDAPlace());
// rename
op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name);
};
GraphPatternDetector gpd1;
patterns::SkipLayernorm skip_layernorm(gpd1.mutable_pattern(),
"remove_padding_recover_padding_pass");
skip_layernorm();
auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"skip_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
skip_layernorm);
insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
found_subgraph_count++;
};
gpd1(graph, handler1);
GraphPatternDetector gpd2;
patterns::MultiheadMatmul multihead_matmul(
gpd2.mutable_pattern(), "remove_padding_recover_padding_pass");
multihead_matmul();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"multihead_matmul";
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input,
multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_op, multihead_matmul_op,
multihead_matmul);
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
multihead_matmul);
insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op);
insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out);
found_subgraph_count++;
};
gpd2(graph, handler2);
GraphPatternDetector gpd3;
patterns::Fc fc(gpd3.mutable_pattern(),
"remove_padding_recover_padding_pass");
fc();
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: fc";
GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc);
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc);
insert_remove_padding_op(fc_input, fc_op);
insert_recover_padding_op(fc_op, fc_out);
found_subgraph_count++;
};
gpd3(graph, handler3);
GraphPatternDetector gpd4;
patterns::Activation activation(gpd4.mutable_pattern(),
"remove_padding_recover_padding_pass");
activation();
auto handler4 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(3)
<< "remove_padding_recover_padding_pass for transformer: activation";
GET_IR_NODE_FROM_SUBGRAPH(activation_input, activation_input, activation);
GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation);
insert_remove_padding_op(activation_input, activation_op);
insert_recover_padding_op(activation_op, activation_out);
found_subgraph_count++;
};
gpd4(graph, handler4);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(remove_padding_recover_padding_pass,
paddle::framework::ir::RemovePaddingRecoverPaddingPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct SkipLayernorm : public PatternBase {
SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "skip_layernorm") {}
void operator()();
PATTERN_DECL_NODE(skip_layernorm_x);
PATTERN_DECL_NODE(skip_layernorm_y);
PATTERN_DECL_NODE(skip_layernorm_op);
PATTERN_DECL_NODE(skip_layernorm_out);
};
struct MultiheadMatmul : public PatternBase {
MultiheadMatmul(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {}
void operator()();
PATTERN_DECL_NODE(multihead_matmul_input);
PATTERN_DECL_NODE(multihead_matmul_op);
PATTERN_DECL_NODE(multihead_matmul_out);
};
struct Fc : public PatternBase {
Fc(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "fc") {}
void operator()();
PATTERN_DECL_NODE(fc_input);
PATTERN_DECL_NODE(fc_op);
PATTERN_DECL_NODE(fc_out);
};
struct Activation : public PatternBase {
Activation(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "activation") {}
void operator()();
PATTERN_DECL_NODE(activation_input);
PATTERN_DECL_NODE(activation_op);
PATTERN_DECL_NODE(activation_out);
};
} // namespace patterns
class RemovePaddingRecoverPaddingPass : public FusePassBase {
public:
RemovePaddingRecoverPaddingPass() {}
virtual ~RemovePaddingRecoverPaddingPass() {}
protected:
void ApplyImpl(Graph *graph) const;
const std::string name_scope_{"remove_padding_recover_padding_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/set_transformer_input_convert_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
SetTransformerInputConvertPass::SetTransformerInputConvertPass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.End();
}
namespace patterns {
void SetTransformerInputConvert::operator()() {
std::unordered_set<std::string> lookup_table_ops{"lookup_table",
"lookup_table_v2"};
// Create nodes for lookup_table1 op.
auto *lookup_table1_x = pattern->NewNode(lookup_table1_x_repr())
->assert_is_ops_input(lookup_table_ops, "Ids");
auto *lookup_table1_w = pattern->NewNode(lookup_table1_w_repr())
->assert_is_ops_input(lookup_table_ops, "W");
auto *lookup_table1_op =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_table_ops);
auto *lookup_table1_out = pattern->NewNode(lookup_table1_out_repr())
->assert_is_ops_output(lookup_table_ops)
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
// Create nodes for lookup_table2 op.
auto *lookup_table2_x = pattern->NewNode(lookup_table2_x_repr())
->assert_is_ops_input(lookup_table_ops, "Ids");
auto *lookup_table2_w = pattern->NewNode(lookup_table2_w_repr())
->assert_is_ops_input(lookup_table_ops, "W");
auto *lookup_table2_op =
pattern->NewNode(lookup_table2_repr())->assert_is_ops(lookup_table_ops);
auto *lookup_table2_out = pattern->NewNode(lookup_table2_out_repr())
->assert_is_ops_output(lookup_table_ops)
->AsIntermediate()
->assert_is_op_input("elementwise_add", "Y");
// Create nodes for elementwise_add op.
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_only_output_of_op("elementwise_add");
// links nodes.
lookup_table1_op->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
lookup_table2_op->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
elementwise_op->LinksFrom({lookup_table1_out, lookup_table2_out})
.LinksTo({elementwise_out});
}
} // namespace patterns
void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
patterns::SetTransformerInputConvert fused_pattern(
gpd.mutable_pattern(), "transformer_input_convert_pass");
fused_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "transformer_input_convert_pass in op compat failed.";
return;
}
VLOG(3) << "transformer_input_convert_pass for pos_id, max_seqlen";
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, fused_pattern);
// create op, var in graph
OpDesc new_desc;
new_desc.SetType("transformer_input_convert");
// inputs
new_desc.SetInput("X", {lookup_table2_x->Name()});
// outputs
std::vector<std::string> output_0 = {"pos_id_tensor"};
std::vector<std::string> output_1 = {"max_seqlen_tensor"};
new_desc.SetOutput("PosId", output_0);
new_desc.SetOutput("MaxSeqlen", output_1);
std::string transformer_input_convert_out0_name = "pos_id_tensor";
std::string transformer_input_convert_out1_name = "max_seqlen_tensor";
VarDesc transformer_input_convert_out0(transformer_input_convert_out0_name);
VarDesc transformer_input_convert_out1(transformer_input_convert_out1_name);
transformer_input_convert_out0.SetDataType(proto::VarType::INT32);
transformer_input_convert_out1.SetDataType(proto::VarType::INT32);
transformer_input_convert_out0.SetShape({-1});
transformer_input_convert_out1.SetShape({-1});
transformer_input_convert_out0.SetPersistable(false);
transformer_input_convert_out1.SetPersistable(false);
auto new_op_node = graph->CreateOpNode(&new_desc);
auto transformer_input_convert_out0_node =
graph->CreateVarNode(&transformer_input_convert_out0);
auto transformer_input_convert_out1_node =
graph->CreateVarNode(&transformer_input_convert_out1);
// needn't create variable in scope
IR_NODE_LINK_TO(lookup_table2_x, new_op_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(set_transformer_input_convert_pass,
paddle::framework::ir::SetTransformerInputConvertPass);
REGISTER_PASS_CAPABILITY(set_transformer_input_convert_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("lookup_table", 1)
.LE("lookup_table_v2", 1)
.LE("elementweise_add", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// in_var emb in_var emb
// | | | |
// lookup_table lookup_table
// | |
// lkt_var lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
//
struct SetTransformerInputConvert : public PatternBase {
SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "transformer_input_convert") {}
void operator()();
// declare operator node's name
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(elementwise);
// declare variable node's name
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(elementwise_out);
};
} // namespace patterns
class SetTransformerInputConvertPass : public FusePassBase {
public:
SetTransformerInputConvertPass();
virtual ~SetTransformerInputConvertPass() {}
protected:
void ApplyImpl(Graph *graph) const;
const std::string name_scope_{"transformer_input_convert_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -377,12 +377,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
trt_engine->SetDLACore(Get<int>("trt_dla_core"));
trt_engine->SetUseInspector(Get<bool>("use_inspector"));
trt_engine->SetWithErnie(
(graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)) ||
(graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)));
trt_engine->SetWithErnie(graph->Has(framework::ir::kMultiheadMatmulPass));
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
......
......@@ -98,6 +98,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"multihead_matmul_fuse_pass_v3", //
"skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", //
......@@ -108,6 +109,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
// "remove_padding_recover_padding_pass", //
// "delete_remove_padding_recover_padding_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册