未验证 提交 2810dfea 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-Inference] new general transformer inference support (#43077)

* new general transformer inference support
上级 0cb9dae5
...@@ -107,6 +107,9 @@ target_link_libraries(generate_pass pass_desc_proto) ...@@ -107,6 +107,9 @@ target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT) if(WITH_TENSORRT)
pass_library(trt_map_matmul_to_mul_pass inference) pass_library(trt_map_matmul_to_mul_pass inference)
pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(trt_multihead_matmul_fuse_pass inference)
pass_library(trt_skip_layernorm_fuse_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference) pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_skip_layernorm_fuse_pass inference) pass_library(preln_skip_layernorm_fuse_pass inference)
pass_library(set_transformer_input_convert_pass inference) pass_library(set_transformer_input_convert_pass inference)
......
...@@ -430,13 +430,15 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { ...@@ -430,13 +430,15 @@ void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
bool enable_int8 = Get<bool>("enable_int8"); bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss"); bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved"); bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape"); bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) { if (!(enable_int8 && use_varseqlen && with_interleaved &&
with_dynamic_shape)) {
VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, " VLOG(4) << "preln_embedding_eltwise_layernorm_fuse_pass need: use_trt, "
"enable_int8, " "enable_int8, "
"use_oss, with_interleaved, with_dynamic_shape. Stop this pass, " "use_varseqlen, with_interleaved, with_dynamic_shape. Stop this "
"pass, "
"please reconfig."; "please reconfig.";
return; return;
} }
......
...@@ -109,12 +109,13 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -109,12 +109,13 @@ void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_skip_layernorm_fuse", graph); FusePassBase::Init("preln_skip_layernorm_fuse", graph);
bool enable_int8 = Get<bool>("enable_int8"); bool enable_int8 = Get<bool>("enable_int8");
bool use_oss = Get<bool>("use_oss"); bool use_varseqlen = Get<bool>("use_varseqlen");
bool with_interleaved = Get<bool>("with_interleaved"); bool with_interleaved = Get<bool>("with_interleaved");
bool with_dynamic_shape = Get<bool>("with_dynamic_shape"); bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
if (!(enable_int8 && use_oss && with_interleaved && with_dynamic_shape)) { if (!(enable_int8 && use_varseqlen && with_interleaved &&
with_dynamic_shape)) {
VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, " VLOG(4) << "preln_skip_layernorm_fuse_pass need: use_trt, enable_int8, "
"use_oss, " "use_varseqlen, "
"with_interleaved, with_dynamic_shape. Stop this pass, please " "with_interleaved, with_dynamic_shape. Stop this pass, please "
"reconfig. "; "reconfig. ";
return; return;
......
...@@ -22,6 +22,19 @@ namespace paddle { ...@@ -22,6 +22,19 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace patterns {
void EmbEltwiseLayernorm::operator()() {
  // Pattern: a single fused_embedding_eltwise_layernorm op together with its
  // "Out" output variable. The only structural constraint is the op -> out
  // link; inputs are intentionally left unconstrained.
  auto* fused_op = pattern->NewNode(emb_elt_layernorm_op_repr())
                       ->assert_is_op("fused_embedding_eltwise_layernorm");
  auto* fused_out =
      pattern->NewNode(emb_elt_layernorm_out_repr())
          ->assert_is_op_output("fused_embedding_eltwise_layernorm", "Out");
  fused_op->LinksTo({fused_out});
}
void SkipLayernorm::operator()() { void SkipLayernorm::operator()() {
// Create nodes for skip_layernorm. // Create nodes for skip_layernorm.
auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr()) auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
...@@ -59,16 +72,12 @@ void Fc::operator()() { ...@@ -59,16 +72,12 @@ void Fc::operator()() {
auto* fc_input = auto* fc_input =
pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input"); pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input");
auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc"); auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
auto* fc_out = fc_op->LinksFrom({fc_input});
pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
// Add links for fc op.
fc_op->LinksFrom({fc_input}).LinksTo({fc_out});
} }
void Activation::operator()() { void Activation::operator()() {
// Create nodes for activation. // Create nodes for activation.
std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "tanh"}; std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "gelu"};
auto* activation_input = pattern->NewNode(activation_input_repr()) auto* activation_input = pattern->NewNode(activation_input_repr())
->assert_is_ops_input(activation_ops); ->assert_is_ops_input(activation_ops);
auto* activation_op = auto* activation_op =
...@@ -82,6 +91,18 @@ void Activation::operator()() { ...@@ -82,6 +91,18 @@ void Activation::operator()() {
} // namespace patterns } // namespace patterns
void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
  // Gate: run this pass only when TensorRT variable-sequence-length mode is
  // enabled, both the transformer pos_id and mask_id inputs are configured,
  // and the graph carries the markers left by the
  // embedding-eltwise-layernorm and multihead-matmul fuse passes.
  bool use_varseqlen = Get<bool>("use_varseqlen");
  std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
  std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
  if (use_varseqlen && pos_id != "" && mask_id != "" &&
      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
      graph->Has(framework::ir::kMultiheadMatmulPass)) {
    VLOG(3) << "start varseqlen remove_padding_recover_padding_pass";
  } else {
    // Preconditions not met: leave the graph untouched.
    return;
  }
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
...@@ -91,14 +112,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -91,14 +112,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
// Create an remove_padding op node // Create an remove_padding op node
auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) { auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) {
// create op, var in graph // create op, var in graph
OpDesc remove_padding; OpDesc remove_padding(op_node->Op()->Block());
std::string remove_padding_out_name = std::string remove_padding_out_name =
input_node->Name() + ".remove_padding"; input_node->Name() + ".remove_padding";
auto* remove_padding_out =
VarDesc remove_padding_out(remove_padding_out_name); op_node->Op()->Block()->Var(remove_padding_out_name);
remove_padding_out.SetDataType(input_node->Var()->GetDataType()); remove_padding_out->SetDataType(input_node->Var()->GetDataType());
remove_padding_out.SetShape(input_node->Var()->GetShape()); remove_padding_out->SetShape(input_node->Var()->GetShape());
remove_padding_out.SetPersistable(false); remove_padding_out->SetPersistable(false);
// remove_padding_op // remove_padding_op
remove_padding.SetType("remove_padding"); remove_padding.SetType("remove_padding");
...@@ -110,7 +131,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -110,7 +131,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
remove_padding.SetOutput("Out", {remove_padding_out_name}); remove_padding.SetOutput("Out", {remove_padding_out_name});
auto remove_padding_op_node = graph->CreateOpNode(&remove_padding); auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
auto remove_padding_out_node = graph->CreateVarNode(&remove_padding_out); auto remove_padding_out_node = graph->CreateVarNode(remove_padding_out);
// replace link // replace link
for (size_t i = 0; i < input_node->outputs.size(); ++i) { for (size_t i = 0; i < input_node->outputs.size(); ++i) {
...@@ -145,13 +166,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -145,13 +166,14 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
// create an remove_padding op node // create an remove_padding op node
auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) { auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) {
// create op, var in graph // create op, var in graph
OpDesc recover_padding; OpDesc recover_padding(op_node->Op()->Block());
std::string recover_padding_input_name = std::string recover_padding_input_name =
out_node->Name() + ".recover_padding"; out_node->Name() + ".recover_padding";
VarDesc recover_padding_input(recover_padding_input_name); auto* recover_padding_input =
recover_padding_input.SetDataType(out_node->Var()->GetDataType()); op_node->Op()->Block()->Var(recover_padding_input_name);
recover_padding_input.SetShape(out_node->Var()->GetShape()); recover_padding_input->SetDataType(out_node->Var()->GetDataType());
recover_padding_input.SetPersistable(false); recover_padding_input->SetShape(out_node->Var()->GetShape());
recover_padding_input->SetPersistable(false);
// recover_padding_op // recover_padding_op
recover_padding.SetType("recover_padding"); recover_padding.SetType("recover_padding");
...@@ -164,7 +186,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -164,7 +186,7 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
auto recover_padding_op_node = graph->CreateOpNode(&recover_padding); auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
auto recover_padding_input_node = auto recover_padding_input_node =
graph->CreateVarNode(&recover_padding_input); graph->CreateVarNode(recover_padding_input);
// replace link // replace link
for (size_t i = 0; i < op_node->outputs.size(); ++i) { for (size_t i = 0; i < op_node->outputs.size(); ++i) {
...@@ -195,39 +217,36 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -195,39 +217,36 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name); op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name);
}; };
GraphPatternDetector gpd1; bool check_flag = true;
patterns::SkipLayernorm skip_layernorm(gpd1.mutable_pattern(),
"remove_padding_recover_padding_pass");
skip_layernorm();
auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph, GraphPatternDetector gpd0;
patterns::EmbEltwiseLayernorm fused_embedding_eltwise_layernorm(
gpd0.mutable_pattern(), "remove_padding_recover_padding_pass");
fused_embedding_eltwise_layernorm();
auto handler0 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: " VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"skip_layernorm"; "fused_embedding_eltwise_layernorm";
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x, GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_op, emb_elt_layernorm_op,
skip_layernorm); fused_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y, GET_IR_NODE_FROM_SUBGRAPH(emb_elt_layernorm_out, emb_elt_layernorm_out,
skip_layernorm); fused_embedding_eltwise_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
skip_layernorm);
GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
skip_layernorm);
insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op); insert_recover_padding_op(emb_elt_layernorm_op, emb_elt_layernorm_out);
insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
found_subgraph_count++; found_subgraph_count++;
}; };
gpd1(graph, handler1); gpd0(graph, handler0);
GraphPatternDetector gpd2; GraphPatternDetector gpd1;
patterns::MultiheadMatmul multihead_matmul( patterns::MultiheadMatmul multihead_matmul(
gpd2.mutable_pattern(), "remove_padding_recover_padding_pass"); gpd1.mutable_pattern(), "remove_padding_recover_padding_pass");
multihead_matmul(); multihead_matmul();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph, std::vector<int64_t> multihead_matmul_input_shape;
auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(3) << "remove_padding_recover_padding_pass for transformer: " VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
"multihead_matmul"; "multihead_matmul";
...@@ -239,11 +258,57 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -239,11 +258,57 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
multihead_matmul); multihead_matmul);
multihead_matmul_input_shape = multihead_matmul_input->Var()->GetShape();
insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op); insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op);
insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out); insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out);
found_subgraph_count++; found_subgraph_count++;
}; };
gpd1(graph, handler1);
  // Detect skip_layernorm subgraphs and wrap them with remove_padding /
  // recover_padding ops so they operate on the packed (padding-free) layout.
  GraphPatternDetector gpd2;
  patterns::SkipLayernorm skip_layernorm(gpd2.mutable_pattern(),
                                         "remove_padding_recover_padding_pass");
  skip_layernorm();

  auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* graph) {
    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
               "skip_layernorm";

    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
                              skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
                              skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
                              skip_layernorm);
    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
                              skip_layernorm);

    // Shape guard: the skip_layernorm input must have the same rank as the
    // multihead_matmul input recorded by the earlier handler; otherwise
    // removing padding would produce inconsistent shapes, so bail out.
    std::vector<int64_t> skip_layernorm_x_shape =
        skip_layernorm_x->Var()->GetShape();
    if (skip_layernorm_x_shape.size() != multihead_matmul_input_shape.size()) {
      check_flag = false;
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
                 "remove_padding pass.";
      return;
    }
    // Every dimension must match element-wise as well.
    for (size_t i = 0; i < skip_layernorm_x_shape.size(); ++i) {
      if (skip_layernorm_x_shape[i] != multihead_matmul_input_shape[i]) {
        check_flag = false;
      }
    }
    if (!check_flag) {
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
                 "remove_padding pass.";
      return;
    }

    // Both residual inputs get a remove_padding op; the op's output gets a
    // recover_padding op so downstream consumers see the padded layout again.
    insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
    insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
    insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);

    found_subgraph_count++;
  };
gpd2(graph, handler2); gpd2(graph, handler2);
GraphPatternDetector gpd3; GraphPatternDetector gpd3;
...@@ -257,11 +322,39 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -257,11 +322,39 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc); GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc);
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc); GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc);
insert_remove_padding_op(fc_input, fc_op); std::vector<int64_t> fc_input_shape = fc_input->Var()->GetShape();
insert_recover_padding_op(fc_op, fc_out); if ((fc_input_shape.size() != multihead_matmul_input_shape.size()) ||
(fc_input_shape.size() != 3)) {
check_flag = false;
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
if (fc_input_shape[0] != multihead_matmul_input_shape[0]) {
check_flag = false;
}
if (fc_input_shape[1] != multihead_matmul_input_shape[1]) {
check_flag = false;
}
if ((fc_input_shape[2] != multihead_matmul_input_shape[2]) &&
(fc_input_shape[2] != 4 * multihead_matmul_input_shape[2])) {
check_flag = false;
}
if (BOOST_GET_CONST(int, fc_op->Op()->GetAttr("in_num_col_dims")) != 2) {
check_flag = false;
}
if (!check_flag) {
VLOG(3) << "Transformer model remove_padding shape check failed, return "
"remove_padding pass.";
return;
}
fc_op->Op()->RemoveAttr("in_num_col_dims");
fc_op->Op()->SetAttr("in_num_col_dims", 1);
insert_remove_padding_op(fc_input, fc_op);
insert_recover_padding_op(fc_op, fc_op->outputs[0]);
found_subgraph_count++; found_subgraph_count++;
}; };
gpd3(graph, handler3); gpd3(graph, handler3);
...@@ -280,6 +373,31 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -280,6 +373,31 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation); GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation);
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation); GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation);
    // Shape guard against the multihead_matmul input captured earlier:
    // the activation input must be rank-3 with matching dims 0 and 1; its
    // last dim may equal the multihead input dim or 4x that dim (presumably
    // the transformer feed-forward expansion — TODO confirm). Otherwise this
    // subgraph is skipped.
    std::vector<int64_t> activation_input_shape =
        activation_input->Var()->GetShape();
    if ((activation_input_shape.size() !=
         multihead_matmul_input_shape.size()) ||
        (activation_input_shape.size() != 3)) {
      check_flag = false;
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
                 "remove_padding pass.";
      return;
    }
    if (activation_input_shape[0] != multihead_matmul_input_shape[0]) {
      check_flag = false;
    }
    if (activation_input_shape[1] != multihead_matmul_input_shape[1]) {
      check_flag = false;
    }
    if ((activation_input_shape[2] != multihead_matmul_input_shape[2]) &&
        (activation_input_shape[2] != 4 * multihead_matmul_input_shape[2])) {
      check_flag = false;
    }
    if (!check_flag) {
      VLOG(3) << "Transformer model remove_padding shape check failed, return "
                 "remove_padding pass.";
      return;
    }
insert_remove_padding_op(activation_input, activation_op); insert_remove_padding_op(activation_input, activation_op);
insert_recover_padding_op(activation_op, activation_out); insert_recover_padding_op(activation_op, activation_out);
......
...@@ -32,6 +32,14 @@ namespace paddle { ...@@ -32,6 +32,14 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace patterns {
// Pattern that matches a single fused_embedding_eltwise_layernorm op and its
// "Out" output variable.
struct EmbEltwiseLayernorm : public PatternBase {
  EmbEltwiseLayernorm(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "emb_elt_layernorm") {}

  void operator()();

  // declare operator node's name
  PATTERN_DECL_NODE(emb_elt_layernorm_op);
  // declare variable node's name
  PATTERN_DECL_NODE(emb_elt_layernorm_out);
};
struct SkipLayernorm : public PatternBase { struct SkipLayernorm : public PatternBase {
SkipLayernorm(PDPattern *pattern, const std::string &name_scope) SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
......
...@@ -21,129 +21,134 @@ ...@@ -21,129 +21,134 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
SetTransformerInputConvertPass::SetTransformerInputConvertPass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.End();
}
namespace patterns { namespace patterns {
void SetTransformerInputConvert::operator()() { void SetTransformerInputConvert::operator()(const std::string &pos_id) {
std::unordered_set<std::string> lookup_table_ops{"lookup_table", std::unordered_set<std::string> lookup_table_ops{"lookup_table",
"lookup_table_v2"}; "lookup_table_v2"};
// Create nodes for lookup_table1 op. // Create nodes for lookup_table.
auto *lookup_table1_x = pattern->NewNode(lookup_table1_x_repr()) auto *lookup_table_id =
->assert_is_ops_input(lookup_table_ops, "Ids"); pattern->NewNode(lookup_table_id_repr())
auto *lookup_table1_w = pattern->NewNode(lookup_table1_w_repr()) ->assert_is_ops_input(lookup_table_ops, "Ids")
->assert_is_ops_input(lookup_table_ops, "W"); ->assert_more([&](Node *node) { return node->Name() == pos_id; });
auto *lookup_table1_op = auto *lookup_table_op =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_table_ops); pattern->NewNode(lookup_table_repr())->assert_is_ops(lookup_table_ops);
auto *lookup_table1_out = pattern->NewNode(lookup_table1_out_repr())
->assert_is_ops_output(lookup_table_ops)
->AsIntermediate()
->assert_is_op_input("elementwise_add", "X");
// Create nodes for lookup_table2 op.
auto *lookup_table2_x = pattern->NewNode(lookup_table2_x_repr())
->assert_is_ops_input(lookup_table_ops, "Ids");
auto *lookup_table2_w = pattern->NewNode(lookup_table2_w_repr())
->assert_is_ops_input(lookup_table_ops, "W");
auto *lookup_table2_op =
pattern->NewNode(lookup_table2_repr())->assert_is_ops(lookup_table_ops);
auto *lookup_table2_out = pattern->NewNode(lookup_table2_out_repr())
->assert_is_ops_output(lookup_table_ops)
->AsIntermediate()
->assert_is_op_input("elementwise_add", "Y");
// Create nodes for elementwise_add op.
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out = pattern->NewNode(elementwise_out_repr())
->AsOutput()
->assert_is_only_output_of_op("elementwise_add");
// links nodes. // links nodes.
lookup_table1_op->LinksFrom({lookup_table1_x, lookup_table1_w}) lookup_table_op->LinksFrom({lookup_table_id});
.LinksTo({lookup_table1_out});
lookup_table2_op->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
elementwise_op->LinksFrom({lookup_table1_out, lookup_table2_out})
.LinksTo({elementwise_out});
} }
// Pattern that matches a standalone multihead_matmul op and its "Out" output
// variable; inputs are deliberately unconstrained.
void MultiheadMatmulOP::operator()() {
  // Create nodes for multihead_matmul op.
  auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())
                               ->assert_is_op("multihead_matmul");
  auto *multihead_matmul_out =
      pattern->NewNode(multihead_matmul_out_repr())
          ->assert_is_op_output("multihead_matmul", "Out");

  // links nodes: the link is written as out->LinksFrom({op}), i.e. the
  // output variable is linked from the op node.
  multihead_matmul_out->LinksFrom({multihead_matmul});
}
} // namespace patterns } // namespace patterns
void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const { void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
if (!(graph->Has(framework::ir::kMultiheadMatmulPass) && with_dynamic_shape &&
(pos_id != ""))) {
VLOG(3) << "Transformer model need MultiheadMatmul, and "
"with_dynamic_shape. Stop this pass, "
"please reconfig.";
return;
}
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
int found_subgraph_count = 0; int found_subgraph_count = 0;
Node *transformer_input_convert_out0_node;
GraphPatternDetector gpd; Node *transformer_input_convert_out1_node;
GraphPatternDetector gpd0;
patterns::SetTransformerInputConvert fused_pattern( patterns::SetTransformerInputConvert fused_pattern(
gpd.mutable_pattern(), "transformer_input_convert_pass"); gpd0.mutable_pattern(), "transformer_input_convert_pass");
fused_pattern(); fused_pattern(pos_id);
auto handler0 = [&](const GraphPatternDetector::subgraph_t &subgraph,
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *graph) {
Graph *graph) { VLOG(3)
if (!IsCompat(subgraph, graph)) { << "transformer_input_convert_pass for pos_id, max_seqlen, mask_tensor";
LOG(WARNING) << "transformer_input_convert_pass in op compat failed."; GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, fused_pattern);
return; GET_IR_NODE_FROM_SUBGRAPH(lookup_table_id, lookup_table_id, fused_pattern);
}
VLOG(3) << "transformer_input_convert_pass for pos_id, max_seqlen";
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, fused_pattern);
// create op, var in graph // create op, var in graph
OpDesc new_desc; OpDesc new_desc(lookup_table->Op()->Block());
new_desc.SetType("transformer_input_convert"); new_desc.SetType("transformer_input_convert");
// inputs // inputs
new_desc.SetInput("X", {lookup_table2_x->Name()}); new_desc.SetInput("Input", {lookup_table_id->Name()});
// outputs // outputs
std::vector<std::string> output_0 = {"pos_id_tensor"};
std::vector<std::string> output_1 = {"max_seqlen_tensor"};
new_desc.SetOutput("PosId", output_0);
new_desc.SetOutput("MaxSeqlen", output_1);
std::string transformer_input_convert_out0_name = "pos_id_tensor"; std::string transformer_input_convert_out0_name = "pos_id_tensor";
std::string transformer_input_convert_out1_name = "max_seqlen_tensor"; std::string transformer_input_convert_out1_name = "max_seqlen_tensor";
VarDesc transformer_input_convert_out0(transformer_input_convert_out0_name); std::string transformer_input_convert_out2_name = "mask_tensor";
VarDesc transformer_input_convert_out1(transformer_input_convert_out1_name); std::vector<std::string> output_0 = {transformer_input_convert_out0_name};
transformer_input_convert_out0.SetDataType(proto::VarType::INT32); std::vector<std::string> output_1 = {transformer_input_convert_out1_name};
transformer_input_convert_out1.SetDataType(proto::VarType::INT32); std::vector<std::string> output_2 = {transformer_input_convert_out2_name};
transformer_input_convert_out0.SetShape({-1}); new_desc.SetOutput("PosId", output_0);
transformer_input_convert_out1.SetShape({-1}); new_desc.SetOutput("MaxSeqlen", output_1);
transformer_input_convert_out0.SetPersistable(false); new_desc.SetOutput("MaskTensor", output_2);
transformer_input_convert_out1.SetPersistable(false);
auto *transformer_input_convert_out0 =
lookup_table->Op()->Block()->Var(transformer_input_convert_out0_name);
auto *transformer_input_convert_out1 =
lookup_table->Op()->Block()->Var(transformer_input_convert_out1_name);
auto *transformer_input_convert_out2 =
lookup_table->Op()->Block()->Var(transformer_input_convert_out2_name);
transformer_input_convert_out0->SetDataType(proto::VarType::INT32);
transformer_input_convert_out1->SetDataType(proto::VarType::INT32);
transformer_input_convert_out2->SetDataType(proto::VarType::INT32);
transformer_input_convert_out0->SetShape({-1});
transformer_input_convert_out1->SetShape({-1});
transformer_input_convert_out2->SetShape({-1});
transformer_input_convert_out0->SetPersistable(false);
transformer_input_convert_out1->SetPersistable(false);
transformer_input_convert_out2->SetPersistable(false);
auto new_op_node = graph->CreateOpNode(&new_desc); auto new_op_node = graph->CreateOpNode(&new_desc);
auto transformer_input_convert_out0_node = auto transformer_input_convert_out0_node =
graph->CreateVarNode(&transformer_input_convert_out0); graph->CreateVarNode(transformer_input_convert_out0);
auto transformer_input_convert_out1_node = auto transformer_input_convert_out1_node =
graph->CreateVarNode(&transformer_input_convert_out1); graph->CreateVarNode(transformer_input_convert_out1);
auto transformer_input_convert_out2_node =
graph->CreateVarNode(transformer_input_convert_out2);
// needn't create variable in scope // needn't create variable in scope
IR_NODE_LINK_TO(lookup_table2_x, new_op_node); IR_NODE_LINK_TO(lookup_table_id, new_op_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node); IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node); IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node);
IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out2_node);
found_subgraph_count++; };
gpd0(graph, handler0);
GraphPatternDetector gpd1;
patterns::MultiheadMatmulOP multihead_matmul_pattern(
gpd1.mutable_pattern(), "transformer_input_convert_pass");
multihead_matmul_pattern();
auto handler1 = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
VLOG(3) << "link pos_id, max_seqlen to multihead_matmul.";
GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul,
multihead_matmul_pattern);
IR_NODE_LINK_TO(transformer_input_convert_out0_node, multihead_matmul);
IR_NODE_LINK_TO(transformer_input_convert_out1_node, multihead_matmul);
}; };
gpd1(graph, handler1);
gpd(graph, handler); found_subgraph_count++;
AddStatis(found_subgraph_count); AddStatis(found_subgraph_count);
} }
...@@ -153,9 +158,3 @@ void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const { ...@@ -153,9 +158,3 @@ void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
REGISTER_PASS(set_transformer_input_convert_pass, REGISTER_PASS(set_transformer_input_convert_pass,
paddle::framework::ir::SetTransformerInputConvertPass); paddle::framework::ir::SetTransformerInputConvertPass);
REGISTER_PASS_CAPABILITY(set_transformer_input_convert_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("lookup_table", 1)
.LE("lookup_table_v2", 1)
.LE("elementweise_add", 1));
...@@ -33,41 +33,36 @@ namespace framework { ...@@ -33,41 +33,36 @@ namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace patterns {
// in_var emb in_var emb // in_var emb
// | | | | // | |
// lookup_table lookup_table // lookup_table
// | | // |
// lkt_var lkt_var // lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
// //
struct SetTransformerInputConvert : public PatternBase { struct SetTransformerInputConvert : public PatternBase {
SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope) SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "transformer_input_convert") {} : PatternBase(pattern, name_scope, "transformer_input_convert_pass") {}
void operator()(const std::string &pos_id);
// declare operator node's name
PATTERN_DECL_NODE(lookup_table);
// declare variable node's name
PATTERN_DECL_NODE(lookup_table_id);
};
struct MultiheadMatmulOP : public PatternBase {
MultiheadMatmulOP(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "transformer_input_convert_pass") {}
void operator()(); void operator()();
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(lookup_table1); PATTERN_DECL_NODE(multihead_matmul);
PATTERN_DECL_NODE(lookup_table2); PATTERN_DECL_NODE(multihead_matmul_out);
PATTERN_DECL_NODE(elementwise);
// declare variable node's name
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(elementwise_out);
}; };
} // namespace patterns } // namespace patterns
class SetTransformerInputConvertPass : public FusePassBase { class SetTransformerInputConvertPass : public FusePassBase {
public: public:
SetTransformerInputConvertPass(); SetTransformerInputConvertPass() {}
virtual ~SetTransformerInputConvertPass() {} virtual ~SetTransformerInputConvertPass() {}
protected: protected:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// detect start pattern.
//
// in_var emb in_var emb
// | | | |
// lookup_table lookup_table
// | |
// lkt_var lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
//
// Matches the leading pair of embedding lookups (see diagram above): two
// lookup_table ops whose outputs feed one elementwise_add.
struct TrtEmbedding2Eltwise1Pattern : public PatternBase {
  TrtEmbedding2Eltwise1Pattern(PDPattern* pattern,
                               const std::string& name_scope)
      : PatternBase(pattern, name_scope, "embedding2_eltwise1") {}

  // Builds the pattern nodes/edges inside `pattern`.
  void operator()();

  // Feed inputs of the two lookup branches.
  PATTERN_DECL_NODE(feed1);
  PATTERN_DECL_NODE(feed2);
  // lookup_table inputs (ids x, embedding weight w), ops, and outputs.
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table2_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table2_w);
  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table2);
  PATTERN_DECL_NODE(lookup_table1_out);
  PATTERN_DECL_NODE(lookup_table2_out);
  // elementwise_add combining the two embedding outputs.
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
};
// detect repeats inner pattern
//
// elt_out_var in_var emb
// \ | |
// \ lookup_table
// \ |
// \ lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
//
// Matches one repeated inner step (see diagram above): the previous add
// result plus one more lookup_table output feeding another elementwise_add.
struct TrtEmbedding1Eltwise1Pattern : public PatternBase {
  TrtEmbedding1Eltwise1Pattern(PDPattern* pattern,
                               const std::string& name_scope)
      : PatternBase(pattern, name_scope, "embedding1_eltwise1") {}

  // Builds the pattern nodes/edges inside `pattern`.
  void operator()();

  PATTERN_DECL_NODE(feed1);
  PATTERN_DECL_NODE(lookup_table1_x);
  PATTERN_DECL_NODE(lookup_table1_w);
  PATTERN_DECL_NODE(lookup_table1);
  PATTERN_DECL_NODE(lookup_table1_out);
  // Output of the previous elementwise_add in the chain.
  PATTERN_DECL_NODE(eltwise_add_in);
  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
};
// detect end pattern
//
// elementwise_add
// |
// elt_out_var
// scale | bias
// \ | /
// layer_norm
//
// Matches the tail of the chain (see diagram above): the final
// elementwise_add output consumed by layer_norm with its Scale/Bias inputs.
struct TrtSkipLayerNorm : public PatternBase {
  TrtSkipLayerNorm(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "skip_layernorm") {}

  // Builds the pattern nodes/edges inside `pattern`.
  void operator()();

  PATTERN_DECL_NODE(eltwise_add);
  PATTERN_DECL_NODE(eltwise_add_out);
  PATTERN_DECL_NODE(layer_norm);
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  // Delete the mean and var nodes in the graph.
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};
} // namespace patterns
// The TrtEmbeddingEltwiseLayerNormFusePass detect the following pattern:
//
// inputs operator output
// --------------------------------------------------------------------
// (word, weights_0) lookup_table -> word_emb
// (pos, weights_1) lookup_table -> pos_emb
// (sent, weights_2) lookup_table -> sent_emb
// (word_emb, pos_emb) elementweise_add -> elementwise_out_0
// (elemtwise_out_0, sent_emb) elementweise_add -> elementwise_out_1
// (elementwise_out_1, scale, bias) layer_norm -> layer_norm_out
//
// and then convert the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
// scale, baias) embedding_eltwise_layernorm -> layer_norm_out
//
//
// in_var emb_var in_var emb_var in_var emb_var in_var emb_var
// | | | | | | | |
// lookup_table lookup_table lookup_table ... lookup_table
// | | | |
// lkt_var lkt_var lkt_var lkt_var
// \ / | ... |
// elementwise_add | |
// \ / |
// elementwise_add |
// | |
// elt_var /
// \ /
// elementwise_add
// |
// layer_norm
// Fuses the lookup_table / elementwise_add chain plus the trailing
// layer_norm (diagram above) into a single embedding_eltwise_layernorm op
// for the TensorRT engine.
class TrtEmbeddingEltwiseLayerNormFusePass : public FusePassBase {
 public:
  TrtEmbeddingEltwiseLayerNormFusePass();
  virtual ~TrtEmbeddingEltwiseLayerNormFusePass() {}

 protected:
  // Detects and rewrites matching subgraphs in `graph`.
  void ApplyImpl(Graph* graph) const;
  // Performs the fusion under `name_scope`; implementation not in this file.
  int BuildFusion(Graph* graph, const std::string& name_scope
                  /*const Scope* scope*/) const;
  const std::string name_scope_{"trt_embedding_eltwise_layernorm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Multi-head attention pattern: three parallel mul + elementwise_add +
// reshape2 + transpose2 branches (Q/K/V), a scale on Q, the Q*K matmul
// with additive mask and softmax, the attention*V matmul, and the final
// transpose2 + reshape2.
struct TrtMultiHeadMatmulPattern : public PatternBase {
  TrtMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "multihead_matmul") {}

  // Builds the pattern and returns its terminal node.
  PDNode* operator()();

  // declare operator node's name
  PATTERN_DECL_NODE(input0);
  // Q/K/V projection muls: weights and outputs.
  PATTERN_DECL_NODE(mul0);
  PATTERN_DECL_NODE(mul1);
  PATTERN_DECL_NODE(mul2);
  PATTERN_DECL_NODE(mul0_w);
  PATTERN_DECL_NODE(mul1_w);
  PATTERN_DECL_NODE(mul2_w);
  PATTERN_DECL_NODE(mul0_out);
  PATTERN_DECL_NODE(mul1_out);
  PATTERN_DECL_NODE(mul2_out);
  // Per-branch bias adds.
  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_out);
  PATTERN_DECL_NODE(eltadd1_out);
  PATTERN_DECL_NODE(eltadd2_out);
  // Head split (reshape2) / permute (transpose2) for each branch plus the
  // final merge back (the *_qkv nodes).
  PATTERN_DECL_NODE(reshape2_0);
  PATTERN_DECL_NODE(reshape2_1);
  PATTERN_DECL_NODE(reshape2_2);
  PATTERN_DECL_NODE(reshape2_qkv);
  PATTERN_DECL_NODE(reshape2_0_out);
  PATTERN_DECL_NODE(reshape2_1_out);
  PATTERN_DECL_NODE(reshape2_2_out);
  PATTERN_DECL_NODE(reshape2_qkv_out);
  PATTERN_DECL_NODE(transpose2_0);
  PATTERN_DECL_NODE(transpose2_1);
  PATTERN_DECL_NODE(transpose2_2);
  PATTERN_DECL_NODE(transpose2_qkv);
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(transpose2_qkv_out);
  // Attention core: scale, Q*K matmul, mask add, softmax, attn*V matmul.
  PATTERN_DECL_NODE(scale);
  PATTERN_DECL_NODE(scale_out);
  PATTERN_DECL_NODE(matmul_qk);
  PATTERN_DECL_NODE(matmul_qk_out);
  PATTERN_DECL_NODE(eltadd_qk);
  PATTERN_DECL_NODE(eltadd_qk_b);
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
  PATTERN_DECL_NODE(matmul_qkv);
  PATTERN_DECL_NODE(matmul_qkv_out);
};
// V3 variant of the multi-head attention pattern. Same branch structure as
// TrtMultiHeadMatmulPattern but without a separate `scale` op node.
struct TrtMultiHeadMatmulV3Pattern : public PatternBase {
  TrtMultiHeadMatmulV3Pattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "multihead_matmul_v3") {}

  // Builds the pattern and returns its terminal node.
  PDNode* operator()();

  // declare operator node's name
  PATTERN_DECL_NODE(input0);
  // Q/K/V projection muls: weights and outputs.
  PATTERN_DECL_NODE(mul0);
  PATTERN_DECL_NODE(mul1);
  PATTERN_DECL_NODE(mul2);
  PATTERN_DECL_NODE(mul0_w);
  PATTERN_DECL_NODE(mul1_w);
  PATTERN_DECL_NODE(mul2_w);
  PATTERN_DECL_NODE(mul0_out);
  PATTERN_DECL_NODE(mul1_out);
  PATTERN_DECL_NODE(mul2_out);
  // Per-branch bias adds.
  PATTERN_DECL_NODE(eltadd0);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2);    // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd1_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd2_b);  // ELEMENTWISE_ADD
  PATTERN_DECL_NODE(eltadd0_out);
  PATTERN_DECL_NODE(eltadd1_out);
  PATTERN_DECL_NODE(eltadd2_out);
  // Head split / permute per branch and the final merge (the *_qkv nodes).
  PATTERN_DECL_NODE(reshape2_0);
  PATTERN_DECL_NODE(reshape2_1);
  PATTERN_DECL_NODE(reshape2_2);
  PATTERN_DECL_NODE(reshape2_qkv);
  PATTERN_DECL_NODE(reshape2_0_out);
  PATTERN_DECL_NODE(reshape2_1_out);
  PATTERN_DECL_NODE(reshape2_2_out);
  PATTERN_DECL_NODE(reshape2_qkv_out);
  PATTERN_DECL_NODE(transpose2_0);
  PATTERN_DECL_NODE(transpose2_1);
  PATTERN_DECL_NODE(transpose2_2);
  PATTERN_DECL_NODE(transpose2_qkv);
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(transpose2_qkv_out);
  // Attention core: Q*K matmul, mask add, softmax, attn*V matmul.
  PATTERN_DECL_NODE(matmul_qk);
  PATTERN_DECL_NODE(matmul_qk_out);
  PATTERN_DECL_NODE(eltadd_qk);
  PATTERN_DECL_NODE(eltadd_qk_b);
  PATTERN_DECL_NODE(eltadd_qk_out);
  PATTERN_DECL_NODE(softmax_qk);
  PATTERN_DECL_NODE(softmax_qk_out);
  PATTERN_DECL_NODE(matmul_qkv);
  PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
// Fuses the multi-head attention subgraph (TrtMultiHeadMatmulPattern) into
// a single op for the TensorRT engine.
class TrtMultiHeadMatmulFusePass : public FusePassBase {
 public:
  virtual ~TrtMultiHeadMatmulFusePass() {}

 protected:
  // Detects and rewrites matching subgraphs in `graph`.
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"trt_multihead_matmul_fuse"};
};
// V2 variant of the multihead matmul fusion; the constructor (defined
// elsewhere) sets up its op-compat checks, and the pattern match / rewrite
// lives in BuildFusionV2.
class TrtMultiHeadMatmulV2FusePass : public FusePassBase {
 public:
  TrtMultiHeadMatmulV2FusePass();

 protected:
  // Detects and rewrites matching subgraphs in `graph`.
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"trt_multihead_matmul_fuse_v2"};

 private:
  // Performs the fusion; returns the number of fused subgraphs.
  int BuildFusionV2(Graph* graph, const std::string& name_scope,
                    Scope* scope) const;
};
// V3 variant of the multihead matmul fusion, matching
// TrtMultiHeadMatmulV3Pattern via BuildFusionV3.
class TrtMultiHeadMatmulV3FusePass : public FusePassBase {
 public:
  TrtMultiHeadMatmulV3FusePass();

 protected:
  // Detects and rewrites matching subgraphs in `graph`.
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"trt_multihead_matmul_fuse_v3"};

 private:
  // Performs the fusion; returns the number of fused subgraphs.
  int BuildFusionV3(Graph* graph, const std::string& name_scope,
                    Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Detects (x, y) -> elementwise_add -> layer_norm, i.e. a residual ("skip")
// connection followed by layer normalization.
struct TrtSkipLayerNorm : public PatternBase {
  TrtSkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "skip_layernorm") {}

  // Builds the pattern with x/y as the two add inputs and returns the
  // layer_norm output node.
  PDNode *operator()(PDNode *x, PDNode *y);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise);
  PATTERN_DECL_NODE(layer_norm);
  // declare variable node's name
  PATTERN_DECL_NODE(
      elementwise_out);  // (elementwise_input_x,elementwise_input_y) ->
                         // elementwise_out
  PATTERN_DECL_NODE(layer_norm_bias);
  PATTERN_DECL_NODE(layer_norm_scale);
  PATTERN_DECL_NODE(layer_norm_out);
  // Mean/Variance are matched only so the fusion can delete them.
  PATTERN_DECL_NODE(layer_norm_mean);
  PATTERN_DECL_NODE(layer_norm_variance);
};
PDNode *TrtSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
  // Constrain the two inputs to feed an elementwise_add as X and Y.
  x->assert_is_op_input("elementwise_add", "X");
  y->assert_is_op_input("elementwise_add", "Y");

  // elementwise_add op and its sole output variable.
  auto *add_op =
      pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
  auto *add_out = pattern->NewNode(elementwise_out_repr())
                      ->AsOutput()
                      ->assert_is_only_output_of_op("elementwise_add");
  add_op->LinksFrom({x, y}).LinksTo({add_out});

  // The add result is an intermediate consumed by layer_norm, so it can be
  // removed once the subgraph is fused.
  add_out->AsIntermediate()->assert_is_op_input("layer_norm");
  auto *ln_op = pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
  auto *ln_bias = pattern->NewNode(layer_norm_bias_repr())
                      ->AsInput()
                      ->assert_is_persistable_var()
                      ->assert_is_op_input("layer_norm", "Bias");
  auto *ln_scale = pattern->NewNode(layer_norm_scale_repr())
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("layer_norm", "Scale");
  auto *ln_out = pattern->NewNode(layer_norm_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output("layer_norm", "Y");
  auto *ln_mean = pattern->NewNode(layer_norm_mean_repr())
                      ->AsOutput()
                      ->assert_is_op_output("layer_norm", "Mean");
  auto *ln_variance = pattern->NewNode(layer_norm_variance_repr())
                          ->AsOutput()
                          ->assert_is_op_output("layer_norm", "Variance");

  // Wire up the layer_norm op and return its primary output.
  ln_op->LinksFrom({add_out, ln_bias, ln_scale})
      .LinksTo({ln_out, ln_mean, ln_variance});
  return ln_out;
}
} // namespace patterns
// Fuses (x, y) -> elementwise_add -> layer_norm into a single
// skip_layernorm op, then validates the transformer varseqlen
// configuration when any fusion happened.
void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("skip_layernorm_fuse", graph);
  int found_subgraph_count = 0;

  GraphPatternDetector gpd;
  // Both add inputs must be non-persistable activations; a persistable
  // input would be a weight/bias add, not a residual skip connection.
  auto *x = gpd.mutable_pattern()
                ->NewNode("skip_layernorm_fuse/x")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "X")
                ->assert_var_not_persistable();
  auto *y = gpd.mutable_pattern()
                ->NewNode("skip_layernorm_fuse/y")
                ->AsInput()
                ->assert_is_op_input("elementwise_add", "Y")
                ->assert_var_not_persistable();
  patterns::TrtSkipLayerNorm fused_pattern(gpd.mutable_pattern(),
                                           "skip_layernorm_fuse");
  fused_pattern(x, y);

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *graph) {
    if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
      LOG(WARNING) << "The subgraph is empty.";
      return;
    }

    if (!IsCompat(subgraph, graph)) {
      LOG(WARNING) << "skip_layernorm pass in op compat failed.";
      return;
    }

    VLOG(4) << "handle TrtSkipLayerNorm fuse";
    GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
                              fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
                              fused_pattern);

    std::unordered_set<const Node *> del_node_set;

    // Create a skip_layernorm op that replaces the matched subgraph.
    OpDesc new_desc(elementwise->Op()->Block());
    new_desc.SetType("skip_layernorm");

    // inputs
    new_desc.SetInput("X", {subgraph.at(x)->Name()});
    new_desc.SetInput("Y", {subgraph.at(y)->Name()});
    new_desc.SetInput("Scale", {layer_norm_scale->Name()});
    new_desc.SetInput("Bias", {layer_norm_bias->Name()});

    // Propagate the int8 output threshold when present so the TensorRT
    // converter can run the fused op in int8 mode.
    if (layer_norm->Op()->HasAttr("out_threshold")) {
      new_desc.SetAttr("enable_int8", true);
      new_desc.SetAttr("out_threshold",
                       layer_norm->Op()->GetAttr("out_threshold"));
    }

    // outputs
    new_desc.SetOutput("Out", {layer_norm_out->Name()});

    // attrs
    new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
    new_desc.SetAttr("begin_norm_axis",
                     layer_norm->Op()->GetAttr("begin_norm_axis"));

    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.

    // The fused ops, the intermediate add output, and the layer_norm
    // statistics are no longer needed after fusion.
    del_node_set.insert(elementwise);
    del_node_set.insert(layer_norm);
    del_node_set.insert(elementwise_out);
    del_node_set.insert(layer_norm_mean);
    del_node_set.insert(layer_norm_variance);
    GraphSafeRemoveNodes(graph, del_node_set);

    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
    IR_NODE_LINK_TO(subgraph.at(y), fused_node);
    IR_NODE_LINK_TO(layer_norm_scale, fused_node);
    IR_NODE_LINK_TO(layer_norm_bias, fused_node);
    IR_NODE_LINK_TO(fused_node, layer_norm_out);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  if (found_subgraph_count > 0) {
    bool use_varseqlen = Get<bool>("use_varseqlen");
    std::string pos_id = Get<std::string>("tensorrt_transformer_posid");
    std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");

    // Varseqlen mode requires pos_id/mask_id plus the companion embedding
    // and multihead-matmul fusions; any mixed configuration is rejected.
    if (use_varseqlen && pos_id != "" && mask_id != "") {
      if (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
          graph->Has(framework::ir::kMultiheadMatmulPass)) {
        VLOG(3) << "start varseqlen trt_skip_layernorm_fuse_pass";
      } else {
        PADDLE_THROW(platform::errors::Fatal(
            "Use transformer'varseqlen need both "
            "trt_embedding_eltwise_layernorm_fuse_pass and "
            "trt_multihead_matmul_fuse_pass to take effect. "
            "Please use no_varseqlen instead."));
      }
    } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
      VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
    } else {
      PADDLE_THROW(platform::errors::Fatal(
          "Inconsistent transformer'varseqlen config: either enable "
          "use_varseqlen and set both pos_id and mask_id, or disable "
          "varseqlen and leave both pos_id and mask_id unset. "
          "Please reconfig."));
    }
  }
  AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(trt_skip_layernorm_fuse_pass,
paddle::framework::ir::TrtSkipLayerNormFusePass);
REGISTER_PASS_CAPABILITY(trt_skip_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("layer_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// | | | |
// other_op1 other_op2 other_op1 other_op2
// | | fuse \ /
// |------elementwise_add -> skip_layernorm
// | |
// layer_norm other_op3
// | |
// other_op3
// |
class Graph;
// Fuses elementwise_add + layer_norm into a skip_layernorm op (see the
// diagram above) for the TensorRT engine.
class TrtSkipLayerNormFusePass : public FusePassBase {
 public:
  // Registers op-compat constraints so only standard-shaped
  // elementwise_add / layer_norm ops are eligible for fusion.
  TrtSkipLayerNormFusePass() {
    AddOpCompat(OpCompat("elementwise_add"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Y")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End()
        .AddAttr("axis")
        // Only full-shape (non-broadcast) adds are fused.
        .IsIntIn({0, -1})
        .End();

    AddOpCompat(OpCompat("layer_norm"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddInput("Scale")
        .IsTensor()
        .End()
        .AddInput("Bias")
        .IsTensor()
        .End()
        .AddOutput("Y")
        .IsTensor()
        .End()
        .AddOutput("Mean")
        .IsTensor()
        .End()
        .AddOutput("Variance")
        .IsTensor()
        .End()
        .AddAttr("epsilon")
        .IsNumGE(0.0f)
        .IsNumLE(0.001f)
        .End()
        .AddAttr("begin_norm_axis")
        .IsNumGT(0)
        .End();
  }
  virtual ~TrtSkipLayerNormFusePass() {}

 protected:
  // Detects and rewrites matching subgraphs in `graph`.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -216,8 +216,12 @@ struct Argument { ...@@ -216,8 +216,12 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine, DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
bool); bool);
DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool); DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool); DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool); DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
DECL_ARGUMENT_FIELD(tensorrt_transformer_posid, TensorRtTransformerPosid,
std::string);
DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid, TensorRtTransformerMaskid,
std::string);
DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path, DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
TensorRtShapeRangeInfoPath, std::string); TensorRtShapeRangeInfoPath, std::string);
DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape, TensorRtTunedDynamicShape, DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape, TensorRtTunedDynamicShape,
......
...@@ -55,9 +55,13 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -55,9 +55,13 @@ void IRPassManager::CreatePasses(Argument *argument,
int pass_num = 0; int pass_num = 0;
for (const std::string &pass_name : passes) { for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_oss", new bool(argument->tensorrt_use_oss())); pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
pass->Set("with_interleaved", pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved())); new bool(argument->tensorrt_with_interleaved()));
pass->Set("tensorrt_transformer_posid",
new std::string(argument->tensorrt_transformer_posid()));
pass->Set("tensorrt_transformer_maskid",
new std::string(argument->tensorrt_transformer_maskid()));
pass->Set("disable_logs", new bool(argument->disable_logs())); pass->Set("disable_logs", new bool(argument->disable_logs()));
auto precision_mode = argument->tensorrt_precision_mode(); auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8; bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
......
...@@ -377,12 +377,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -377,12 +377,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
Get<int>("workspace_size"), precision_mode, calibrator.get(), Get<int>("workspace_size"), precision_mode, calibrator.get(),
Get<int>("gpu_device_id"), min_input_shape, max_input_shape, Get<int>("gpu_device_id"), min_input_shape, max_input_shape,
opt_input_shape, disable_trt_plugin_fp16); opt_input_shape, disable_trt_plugin_fp16);
trt_engine->SetUseOSS(Get<bool>("use_oss")); trt_engine->SetUseOSS(Get<bool>("use_varseqlen"));
trt_engine->SetWithInterleaved(Get<bool>("with_interleaved")); trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
trt_engine->SetTransformerPosid(
Get<std::string>("tensorrt_transformer_posid"));
trt_engine->SetTransformerMaskid(
Get<std::string>("tensorrt_transformer_maskid"));
trt_engine->SetUseDLA(Get<bool>("trt_use_dla")); trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
trt_engine->SetDLACore(Get<int>("trt_dla_core")); trt_engine->SetDLACore(Get<int>("trt_dla_core"));
trt_engine->SetUseInspector(Get<bool>("use_inspector")); trt_engine->SetUseInspector(Get<bool>("use_inspector"));
trt_engine->SetWithErnie(graph->Has(framework::ir::kMultiheadMatmulPass)); trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
if (use_static_engine) { if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData( trt_engine_serialized_data = GetTrtEngineSerializedData(
......
...@@ -256,8 +256,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -256,8 +256,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(trt_dla_core_); CP_MEMBER(trt_dla_core_);
CP_MEMBER(trt_use_static_engine_); CP_MEMBER(trt_use_static_engine_);
CP_MEMBER(trt_use_calib_mode_); CP_MEMBER(trt_use_calib_mode_);
CP_MEMBER(trt_use_oss_); CP_MEMBER(trt_use_varseqlen_);
CP_MEMBER(trt_with_interleaved_); CP_MEMBER(trt_with_interleaved_);
CP_MEMBER(tensorrt_transformer_posid_);
CP_MEMBER(tensorrt_transformer_maskid_);
CP_MEMBER(trt_tuned_dynamic_shape_); CP_MEMBER(trt_tuned_dynamic_shape_);
CP_MEMBER(trt_allow_build_at_runtime_); CP_MEMBER(trt_allow_build_at_runtime_);
CP_MEMBER(collect_shape_range_info_); CP_MEMBER(collect_shape_range_info_);
...@@ -546,7 +548,7 @@ void AnalysisConfig::Exp_DisableTensorRtOPs( ...@@ -546,7 +548,7 @@ void AnalysisConfig::Exp_DisableTensorRtOPs(
trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end()); trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
} }
void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; } void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; }
// TODO(Superjomn) refactor this, buggy. // TODO(Superjomn) refactor this, buggy.
void AnalysisConfig::Update() { void AnalysisConfig::Update() {
...@@ -1034,9 +1036,13 @@ std::string AnalysisConfig::Summary() { ...@@ -1034,9 +1036,13 @@ std::string AnalysisConfig::Summary() {
? shape_range_info_path_ ? shape_range_info_path_
: "false"}); : "false"});
os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"}); os.InsertRow(
{"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"});
os.InsertRow({"tensorrt_with_interleaved", os.InsertRow({"tensorrt_with_interleaved",
trt_with_interleaved_ ? "true" : "false"}); trt_with_interleaved_ ? "true" : "false"});
os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_});
os.InsertRow(
{"tensorrt_transformer_maskid", tensorrt_transformer_maskid_});
os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"}); os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
if (trt_use_dla_) { if (trt_use_dla_) {
os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)}); os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
......
...@@ -853,8 +853,10 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -853,8 +853,10 @@ void AnalysisPredictor::PrepareArgument() {
} }
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_oss_); argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_);
argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_); argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
argument_.SetMinInputShape(config_.min_input_shape_); argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_); argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_); argument_.SetOptimInputShape(config_.optim_input_shape_);
...@@ -1803,6 +1805,9 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) ...@@ -1803,6 +1805,9 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(preln_skip_layernorm) USE_TRT_CONVERTER(preln_skip_layernorm)
USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(roll)
USE_TRT_CONVERTER(strided_slice) USE_TRT_CONVERTER(strided_slice)
USE_TRT_CONVERTER(transformer_input_convert)
USE_TRT_CONVERTER(recover_padding)
USE_TRT_CONVERTER(remove_padding)
#endif #endif
namespace paddle_infer { namespace paddle_infer {
...@@ -1971,6 +1976,20 @@ void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c, ...@@ -1971,6 +1976,20 @@ void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c,
#endif #endif
} }
void InternalUtils::SetTransformerPosid(
paddle_infer::Config *c, const std::string &tensorrt_transformer_posid) {
#ifdef PADDLE_WITH_CUDA
c->tensorrt_transformer_posid_ = tensorrt_transformer_posid;
#endif
}
void InternalUtils::SetTransformerMaskid(
paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid) {
#ifdef PADDLE_WITH_CUDA
c->tensorrt_transformer_maskid_ = tensorrt_transformer_maskid;
#endif
}
void InternalUtils::SyncStream(paddle_infer::Predictor *p) { void InternalUtils::SyncStream(paddle_infer::Predictor *p) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get()); auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
......
...@@ -618,14 +618,14 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -618,14 +618,14 @@ struct PD_INFER_DECL AnalysisConfig {
/// may be more high-performance. Libnvinfer_plugin.so greater than /// may be more high-performance. Libnvinfer_plugin.so greater than
/// V7.2.1 is needed. /// V7.2.1 is needed.
/// ///
void EnableTensorRtOSS(); void EnableVarseqlen();
/// ///
/// \brief A boolean state telling whether to use the TensorRT OSS. /// \brief A boolean state telling whether to use the TensorRT OSS.
/// ///
/// \return bool Whether to use the TensorRT OSS. /// \return bool Whether to use the TensorRT OSS.
/// ///
bool tensorrt_oss_enabled() { return trt_use_oss_; } bool tensorrt_varseqlen_enabled() { return trt_use_varseqlen_; }
/// ///
/// \brief Enable TensorRT DLA /// \brief Enable TensorRT DLA
...@@ -954,8 +954,10 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -954,8 +954,10 @@ struct PD_INFER_DECL AnalysisConfig {
Precision tensorrt_precision_mode_{Precision::kFloat32}; Precision tensorrt_precision_mode_{Precision::kFloat32};
bool trt_use_static_engine_{false}; bool trt_use_static_engine_{false};
bool trt_use_calib_mode_{true}; bool trt_use_calib_mode_{true};
bool trt_use_oss_{false}; bool trt_use_varseqlen_{false};
bool trt_with_interleaved_{false}; bool trt_with_interleaved_{false};
std::string tensorrt_transformer_posid_{""};
std::string tensorrt_transformer_maskid_{""};
bool trt_use_dla_{false}; bool trt_use_dla_{false};
int trt_dla_core_{0}; int trt_dla_core_{0};
std::map<std::string, std::vector<int>> min_input_shape_{}; std::map<std::string, std::vector<int>> min_input_shape_{};
......
...@@ -435,6 +435,12 @@ class PD_INFER_DECL InternalUtils { ...@@ -435,6 +435,12 @@ class PD_INFER_DECL InternalUtils {
static void UpdateConfigInterleaved(paddle_infer::Config* c, static void UpdateConfigInterleaved(paddle_infer::Config* c,
bool with_interleaved); bool with_interleaved);
static void SetTransformerPosid(
paddle_infer::Config* c, const std::string& tensorrt_transformer_posid);
static void SetTransformerMaskid(
paddle_infer::Config* c, const std::string& tensorrt_transformer_maskid);
static void SyncStream(paddle_infer::Predictor* pred); static void SyncStream(paddle_infer::Predictor* pred);
static void SyncStream(cudaStream_t stream); static void SyncStream(cudaStream_t stream);
template <typename T> template <typename T>
......
...@@ -94,25 +94,25 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -94,25 +94,25 @@ const std::vector<std::string> kTRTSubgraphPasses({
"add_support_int8_pass", // "add_support_int8_pass", //
// "fc_fuse_pass", // // "fc_fuse_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", // "trt_embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", // "preln_embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", // "trt_multihead_matmul_fuse_pass_v2", //
"multihead_matmul_fuse_pass_v3", // "trt_multihead_matmul_fuse_pass_v3", //
"skip_layernorm_fuse_pass", // "trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", //
// "set_transformer_input_convert_pass", // // "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", // "trt_squeeze2_matmul_fuse_pass", //
"trt_reshape2_matmul_fuse_pass", // "trt_reshape2_matmul_fuse_pass", //
"trt_flatten2_matmul_fuse_pass", // "trt_flatten2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_mul_pass", // "trt_map_matmul_v2_to_mul_pass", //
"trt_map_matmul_v2_to_matmul_pass", // "trt_map_matmul_v2_to_matmul_pass", //
"trt_map_matmul_to_mul_pass", // "trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", // "fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
// "remove_padding_recover_padding_pass", // "remove_padding_recover_padding_pass", //
// "delete_remove_padding_recover_padding_pass", // "delete_remove_padding_recover_padding_pass", //
// "yolo_box_fuse_pass", // // "yolo_box_fuse_pass", //
"tensorrt_subgraph_pass", // "tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
......
...@@ -303,13 +303,13 @@ void PD_ConfigDisableTensorRtOPs(__pd_keep PD_Config* pd_config, size_t ops_num, ...@@ -303,13 +303,13 @@ void PD_ConfigDisableTensorRtOPs(__pd_keep PD_Config* pd_config, size_t ops_num,
config->Exp_DisableTensorRtOPs(ops_list); config->Exp_DisableTensorRtOPs(ops_list);
} }
void PD_ConfigEnableTensorRtOSS(__pd_keep PD_Config* pd_config) { void PD_ConfigEnableVarseqlen(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
config->EnableTensorRtOSS(); config->EnableVarseqlen();
} }
PD_Bool PD_ConfigTensorRtOssEnabled(__pd_keep PD_Config* pd_config) { PD_Bool PD_ConfigTensorRtOssEnabled(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
return config->tensorrt_oss_enabled(); return config->tensorrt_varseqlen_enabled();
} }
void PD_ConfigEnableTensorRtDla(__pd_keep PD_Config* pd_config, void PD_ConfigEnableTensorRtDla(__pd_keep PD_Config* pd_config,
......
...@@ -432,7 +432,7 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigDisableTensorRtOPs( ...@@ -432,7 +432,7 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigDisableTensorRtOPs(
/// ///
/// \param[in] pd_onfig config /// \param[in] pd_onfig config
/// ///
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableTensorRtOSS( PADDLE_CAPI_EXPORT extern void PD_ConfigEnableVarseqlen(
__pd_keep PD_Config* pd_config); __pd_keep PD_Config* pd_config);
/// ///
/// \brief A boolean state telling whether to use the TensorRT OSS. /// \brief A boolean state telling whether to use the TensorRT OSS.
......
...@@ -500,8 +500,8 @@ func (config *Config) DisableTensorRtOPs(ops []string) { ...@@ -500,8 +500,8 @@ func (config *Config) DisableTensorRtOPs(ops []string) {
/// may be more high-performance. Libnvinfer_plugin.so greater than /// may be more high-performance. Libnvinfer_plugin.so greater than
/// V7.2.1 is needed. /// V7.2.1 is needed.
/// ///
func (config *Config) EnableTensorRtOSS() { func (config *Config) EnableVarseqlen() {
C.PD_ConfigEnableTensorRtOSS(config.c) C.PD_ConfigEnableVarseqlen(config.c)
} }
/// ///
......
...@@ -54,7 +54,7 @@ func TestNewConfig(t *testing.T) { ...@@ -54,7 +54,7 @@ func TestNewConfig(t *testing.T) {
} }
config.SetTRTDynamicShapeInfo(minInputShape, maxInputShape, optInputShape, false) config.SetTRTDynamicShapeInfo(minInputShape, maxInputShape, optInputShape, false)
config.EnableTensorRtOSS() config.EnableVarseqlen()
t.Logf("TensorrtOssEnabled:%+v", config.TensorrtOssEnabled()) t.Logf("TensorrtOssEnabled:%+v", config.TensorrtOssEnabled())
config.EnableTensorRtDLA(0) config.EnableTensorRtDLA(0)
...@@ -138,4 +138,4 @@ func TestONNXRuntime(t *testing.T) { ...@@ -138,4 +138,4 @@ func TestONNXRuntime(t *testing.T) {
config.SetCpuMathLibraryNumThreads(4) config.SetCpuMathLibraryNumThreads(4)
t.Logf("CpuMathLibraryNumThreads:%+v", config.CpuMathLibraryNumThreads()) t.Logf("CpuMathLibraryNumThreads:%+v", config.CpuMathLibraryNumThreads())
} }
\ No newline at end of file
...@@ -56,7 +56,11 @@ nv_library(tensorrt_converter ...@@ -56,7 +56,11 @@ nv_library(tensorrt_converter
strided_slice_op.cc strided_slice_op.cc
preln_skip_layernorm.cc preln_skip_layernorm.cc
roll_op.cc roll_op.cc
transformer_input_convert_op.cc
remove_padding_op.cc
recover_padding_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter) paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_engine tensorrt_converter)
...@@ -30,23 +30,28 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -30,23 +30,28 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer"; VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto word_id_name = op_desc.Input("WordId").front(); auto word_id_name = op_desc.Input("WordId").front();
auto pos_id_name = op_desc.Input("PosId").front(); auto pos_id_name = engine_->tensorrt_transformer_posid();
engine_->Set("ernie_pos_name", new std::string(pos_id_name)); engine_->Set("ernie_pos_name", new std::string(pos_id_name));
auto sent_id_name = op_desc.Input("SentId").front(); auto sent_id_name = op_desc.Input("SentId").front();
auto mask_id_name = engine_->tensorrt_transformer_maskid();
auto word_emb_name = op_desc.Input("WordEmbedding").front(); auto word_emb_name = op_desc.Input("WordEmbedding").front();
auto pos_emb_name = op_desc.Input("PosEmbedding").front(); auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front(); auto sent_emb_name = op_desc.Input("SentEmbedding").front();
std::vector<std::string> id_names; std::vector<std::string> id_names;
std::vector<std::string> emb_names; std::vector<std::string> emb_names;
bool flag_varseqlen =
engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != "";
if (engine_->use_oss()) { if (flag_varseqlen) {
engine_->SetITensor("word_id", engine_->GetITensor(word_id_name));
engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name));
engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name));
id_names = id_names =
std::vector<std::string>{word_id_name, pos_id_name, sent_id_name}; std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
emb_names = emb_names =
...@@ -106,7 +111,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -106,7 +111,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
if (engine_->use_oss()) { if (flag_varseqlen) {
int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0); int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0);
if (enable_int8) { if (enable_int8) {
output_fp16 = 1; output_fp16 = 1;
...@@ -121,7 +126,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -121,7 +126,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
output_fp16, 1, output_fp16, 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Only Precision::KHalf(fp16) is supported when infering " "Only Precision::KHalf(fp16) is supported when infering "
"ernie(bert) model with config.EnableTensorRtOSS(). " "ernie(bert) model with config.EnableVarseqlen(). "
"But Precision::KFloat32 is setted.")); "But Precision::KFloat32 is setted."));
const std::vector<nvinfer1::PluginField> fields{ const std::vector<nvinfer1::PluginField> fields{
{"bert_embeddings_layernorm_beta", bias, {"bert_embeddings_layernorm_beta", bias,
...@@ -159,8 +164,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -159,8 +164,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
plugin_inputs.emplace_back( plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens, engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2 // eval_placeholder_2
auto max_seqlen_tensor = auto max_seqlen_tensor = engine_->GetITensor(mask_id_name);
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer = auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor); TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
nvinfer1::Dims shape_dim; nvinfer1::Dims shape_dim;
...@@ -193,8 +197,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -193,8 +197,8 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale); engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale);
} }
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
VLOG(4) VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and "
<< "fused emb_eltwise_layernorm op: use_oss and with_interleaved"; "with_interleaved";
if (!enable_int8) { if (!enable_int8) {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8.")); platform::errors::Fatal("use with_interleaved must be int8."));
...@@ -229,12 +233,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -229,12 +233,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name}, RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name},
test_mode); test_mode);
} }
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} }
}; };
......
...@@ -250,8 +250,7 @@ class FcOpConverter : public OpConverter { ...@@ -250,8 +250,7 @@ class FcOpConverter : public OpConverter {
} }
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead. // not add Shuffle layer in ernie's multihead.
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && if (x_dim.nbDims == 4 && x_num_col_dims == 1) {
x_dim.d[3] == 1 && x_num_col_dims == 2) {
if (enable_int8 || support_int8) { if (enable_int8 || support_int8) {
// add conv1x1 layer // add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
......
...@@ -76,12 +76,14 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -76,12 +76,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
bool flag_varseqlen = engine_->use_varseqlen() &&
engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) { if (flag_varseqlen) {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"use use_oss must be int8 or half, not float32.")); "use use_varseqlen must be int8 or half, not float32."));
} }
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
...@@ -90,7 +92,8 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -90,7 +92,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast<void*>(bias_data), static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())}; static_cast<int32_t>(bias_t->numel())};
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
"with_interleaved";
if (!op_desc.HasAttr("Input_scale")) { if (!op_desc.HasAttr("Input_scale")) {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8.")); platform::errors::Fatal("use with_interleaved must be int8."));
...@@ -233,9 +236,6 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -233,9 +236,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
} }
} }
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2"); "CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr); assert(creator != nullptr);
...@@ -272,18 +272,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -272,18 +272,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0)); plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor); plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask"));
if (engine_->Has("ernie_pos_name")) { plugin_inputs.emplace_back(engine_->GetITensor("pos_id"));
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->Get<std::string>("ernie_pos_name"))); auto max_seqlen_tensor = engine_->GetITensor("mask_id");
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, engine_, Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor)); *const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
......
...@@ -32,7 +32,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -32,7 +32,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer"; VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
if (!(engine_->use_oss() && engine_->with_interleaved())) { if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved")); "PrelnErnie: If you want to use oss, must be with interleaved"));
} }
......
...@@ -24,7 +24,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { ...@@ -24,7 +24,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fused preln_skip_layernorm op to tensorrt layer"; VLOG(4) << "convert fused preln_skip_layernorm op to tensorrt layer";
if (!(engine_->use_oss() && engine_->with_interleaved())) { if (!(engine_->use_varseqlen() && engine_->with_interleaved())) {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved")); "PrelnErnie: If you want to use oss, must be with interleaved"));
} }
...@@ -60,7 +60,8 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { ...@@ -60,7 +60,8 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
VLOG(4) << "fused preln_skip_layernorm op: use_oss and with_interleaved"; VLOG(4)
<< "fused preln_skip_layernorm op: use_varseqlen and with_interleaved";
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "4"); "CustomSkipLayerNormPluginDynamic", "4");
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Recover padding of the transformer's output: scatter the variable-seqlen
 * (packed) representation back into a dense padded layout.
 */
class RecoverPadding : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Recover padding of transformer'output: VarSeqlen -> Padding.";
    // The plugin depends on per-batch runtime sequence lengths, so it is only
    // usable when the engine runs in dynamic-shape mode.
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "recover_padding_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }
    framework::OpDesc op_desc(op, nullptr);
    auto input_name = op_desc.Input("Input").front();
    // Was a stray `std::cout` debug print; route through VLOG instead so
    // normal inference runs do not write to stdout.
    VLOG(3) << "recover_padding input: " << input_name;
    // "pos_id" and "mask_id" are ITensors registered on the engine by the
    // emb_eltwise_layernorm converter when varseqlen is enabled.
    std::vector<nvinfer1::ITensor*> plugin_inputs;
    plugin_inputs.push_back(engine_->GetITensor(input_name));
    plugin_inputs.push_back(engine_->GetITensor("pos_id"));
    plugin_inputs.push_back(engine_->GetITensor("mask_id"));
    auto output_name = op_desc.Output("Out").front();
    plugin::RecoverPaddingPlugin* plugin = new plugin::RecoverPaddingPlugin();
    // Derive the input count from the vector instead of a magic constant.
    nvinfer1::ILayer* layer = engine_->AddDynamicPlugin(
        plugin_inputs.data(), static_cast<int>(plugin_inputs.size()), plugin);
    RreplenishLayerAndOutput(layer, "recover_padding", {output_name},
                             test_mode);
  }
};
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
REGISTER_TRT_OP_CONVERTER(recover_padding, RecoverPadding);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Strip padding from the transformer's input: dense padded layout ->
 * variable-seqlen (packed) layout consumed by the varseqlen kernels.
 */
class RemovePadding : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Remove padding of transformer'input: Padding -> VarSeqlen";
    // Runtime sequence lengths are required, hence dynamic-shape only.
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "remove_padding_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }
    framework::OpDesc op_desc(op, nullptr);
    // Plugin inputs: the padded tensor plus the "pos_id"/"word_id" tensors
    // registered on the engine by the embedding converter.
    std::vector<nvinfer1::ITensor*> trt_inputs;
    trt_inputs.reserve(3);
    trt_inputs.push_back(engine_->GetITensor(op_desc.Input("Input").front()));
    trt_inputs.push_back(engine_->GetITensor("pos_id"));
    trt_inputs.push_back(engine_->GetITensor("word_id"));
    auto* plugin = new plugin::RemovePaddingPlugin();
    nvinfer1::ILayer* layer =
        engine_->AddDynamicPlugin(trt_inputs.data(), trt_inputs.size(), plugin);
    RreplenishLayerAndOutput(layer, "remove_padding_op",
                             {op_desc.Output("Out").front()}, test_mode);
  }
};
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
REGISTER_TRT_OP_CONVERTER(remove_padding, RemovePadding);
...@@ -52,10 +52,13 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -52,10 +52,13 @@ class SkipLayerNormOpConverter : public OpConverter {
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
bool flag_varseqlen = engine_->use_varseqlen() &&
if (engine_->use_oss()) { engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != "";
if (flag_varseqlen) {
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
VLOG(4) << "fused skip_layernorm op: use_oss and with_interleaved"; VLOG(4)
<< "fused skip_layernorm op: use_varseqlen and with_interleaved";
if (!enable_int8) { if (!enable_int8) {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8.")); platform::errors::Fatal("use with_interleaved must be int8."));
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -74,47 +73,12 @@ class SliceOpConverter : public OpConverter { ...@@ -74,47 +73,12 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (engine_->use_oss() && engine_->with_ernie() && bool with_fp16 =
input_dims.nbDims == 4) { engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
std::vector<nvinfer1::ITensor*> plugin_inputs; int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0];
if (engine_->with_interleaved()) { plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); starts, ends, axes, decrease_axis, with_fp16);
nvinfer1::Permutation transpose_embed{2, 1, 0, 3}; layer = engine_->AddDynamicPlugin(&input, 1, plugin);
shuffler_slice->setSecondTranspose(transpose_embed);
engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
out_scale);
shuffler_slice->setName(
("SpecialSlice_interleaved: transpose: (Output: " + output_name +
")")
.c_str());
plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
} else {
plugin_inputs.emplace_back(input);
}
std::string pos_name;
if (engine_->Has("ernie_pos_name")) {
pos_name = engine_->Get<std::string>("ernie_pos_name");
} else {
// hard code for compatibility
pos_name = engine_->network()->getInput(2)->getName();
}
plugin_inputs.emplace_back(
engine_->GetITensor(pos_name)); // cu_seqlens, eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin =
new plugin::SpecialSlicePluginDynamic();
layer = engine_->AddDynamicPlugin(plugin_inputs.data(),
plugin_inputs.size(), plugin);
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
int decrease_axis =
decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin);
}
} else { } else {
bool with_fp16 = bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
}  // namespace proto
}  // namespace framework
}  // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
/*
 * Convert Transformer Input(pos_id, max_seqlen).
 */
class TransformerInputConvert : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(3) << "Convert Transformer Input(pos_id, max_seqlen), use "
               "transformer_input_convert_plugin";
    // The produced pos_id/max_seqlen tensors only make sense with runtime
    // shapes, so dynamic-shape mode is mandatory.
    if (!engine_->with_dynamic_shape()) {
      PADDLE_THROW(platform::errors::Fatal(
          "transformer_input_convert_op: If you want to use transformer, must "
          "be with dynamic shape"));
    }
    framework::OpDesc op_desc(op, nullptr);
    // Collect every declared "Input" tensor into a contiguous array.  The
    // original code passed a single ITensor* together with
    // op_desc.Input("Input").size() as the count, which reads past the pointer
    // whenever more than one input is bound; gathering into a vector keeps the
    // pointer/count pair consistent (identical behavior for the usual
    // single-input case).
    std::vector<nvinfer1::ITensor*> plugin_inputs;
    for (const auto& name : op_desc.Input("Input")) {
      plugin_inputs.push_back(engine_->GetITensor(name));
    }
    // tensorrt_subgraph_pass will rename tensors, so the outputs are published
    // under fixed internal names rather than op_desc.Output(...):
    // auto pos_id_name = op_desc.Output("PosId").front();
    // auto max_seqlen_name = op_desc.Output("MaxSeqlen").front();
    auto pos_id_name = "pos_id_tensor";
    auto max_seqlen_name = "max_seqlen_tensor";
    plugin::TransformerInputConvertPlugin* plugin =
        new plugin::TransformerInputConvertPlugin();
    nvinfer1::ILayer* layer = engine_->AddDynamicPlugin(
        plugin_inputs.data(), static_cast<int>(plugin_inputs.size()), plugin);
    RreplenishLayerAndOutput(layer, "transformer_input_convert",
                             {pos_id_name, max_seqlen_name}, test_mode);
  }
};
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
REGISTER_TRT_OP_CONVERTER(transformer_input_convert, TransformerInputConvert);
...@@ -410,14 +410,19 @@ class TensorRTEngine { ...@@ -410,14 +410,19 @@ class TensorRTEngine {
suffix_counter += 1; suffix_counter += 1;
} }
void SetUseOSS(bool use_oss) { use_oss_ = use_oss; } void SetUseOSS(bool use_varseqlen) { use_varseqlen_ = use_varseqlen; }
void SetUseDLA(bool use_dla) { use_dla_ = use_dla; } void SetUseDLA(bool use_dla) { use_dla_ = use_dla; }
void SetDLACore(int dla_core) { dla_core_ = dla_core; } void SetDLACore(int dla_core) { dla_core_ = dla_core; }
void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; } void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; }
void SetWithInterleaved(bool with_interleaved) { void SetWithInterleaved(bool with_interleaved) {
with_interleaved_ = with_interleaved; with_interleaved_ = with_interleaved;
} }
void SetTransformerPosid(std::string tensorrt_transformer_posid) {
tensorrt_transformer_posid_ = tensorrt_transformer_posid;
}
void SetTransformerMaskid(std::string tensorrt_transformer_maskid) {
tensorrt_transformer_maskid_ = tensorrt_transformer_maskid;
}
void ClearWeights() { void ClearWeights() {
for (auto& weight_pair : weight_map) { for (auto& weight_pair : weight_map) {
weight_pair.second.reset(nullptr); weight_pair.second.reset(nullptr);
...@@ -488,9 +493,15 @@ class TensorRTEngine { ...@@ -488,9 +493,15 @@ class TensorRTEngine {
return ret; return ret;
} }
bool use_oss() { return use_oss_; } bool use_varseqlen() { return use_varseqlen_; }
bool with_ernie() { return with_ernie_; } bool with_ernie() { return with_ernie_; }
bool with_interleaved() { return with_interleaved_; } bool with_interleaved() { return with_interleaved_; }
std::string tensorrt_transformer_posid() {
return tensorrt_transformer_posid_;
}
std::string tensorrt_transformer_maskid() {
return tensorrt_transformer_maskid_;
}
bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; } bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
bool with_dynamic_shape() { return with_dynamic_shape_; } bool with_dynamic_shape() { return with_dynamic_shape_; }
AnalysisConfig::Precision precision() { return precision_; } AnalysisConfig::Precision precision() { return precision_; }
...@@ -612,11 +623,13 @@ class TensorRTEngine { ...@@ -612,11 +623,13 @@ class TensorRTEngine {
ShapeMapType max_input_shape_; ShapeMapType max_input_shape_;
ShapeMapType optim_input_shape_; ShapeMapType optim_input_shape_;
bool disable_trt_plugin_fp16_{false}; bool disable_trt_plugin_fp16_{false};
bool use_oss_{false}; bool use_varseqlen_{false};
bool use_dla_{false}; bool use_dla_{false};
int dla_core_{0}; int dla_core_{0};
bool with_ernie_{false}; bool with_ernie_{false};
bool with_interleaved_{false}; bool with_interleaved_{false};
std::string tensorrt_transformer_posid_;
std::string tensorrt_transformer_maskid_;
nvinfer1::ILogger& logger_; nvinfer1::ILogger& logger_;
// max data size for the buffers. // max data size for the buffers.
......
...@@ -125,7 +125,10 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -125,7 +125,10 @@ struct SimpleOpTypeSetTeller : public Teller {
"strided_slice", "strided_slice",
"fused_preln_embedding_eltwise_layernorm", "fused_preln_embedding_eltwise_layernorm",
"roll", "roll",
"preln_skip_layernorm"}; "preln_skip_layernorm",
"transformer_input_convert",
"recover_padding",
"remove_padding"};
std::unordered_set<std::string> teller_set{ std::unordered_set<std::string> teller_set{
"mul", "mul",
"matmul", "matmul",
...@@ -194,7 +197,10 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -194,7 +197,10 @@ struct SimpleOpTypeSetTeller : public Teller {
"fused_preln_embedding_eltwise_layernorm", "fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm", "preln_skip_layernorm",
"roll", "roll",
"multiclass_nms3"}; "multiclass_nms3",
"transformer_input_convert",
"recover_padding",
"remove_padding"};
}; };
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
......
...@@ -4,7 +4,7 @@ nv_library(tensorrt_plugin ...@@ -4,7 +4,7 @@ nv_library(tensorrt_plugin
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu hard_swish_op_plugin.cu stack_op_plugin.cu
anchor_generator_op_plugin.cu anchor_generator_op_plugin.cu
yolo_box_op_plugin.cu yolo_box_op_plugin.cu
yolo_box_head_op_plugin.cu yolo_box_head_op_plugin.cu
...@@ -14,6 +14,9 @@ nv_library(tensorrt_plugin ...@@ -14,6 +14,9 @@ nv_library(tensorrt_plugin
pool3d_op_plugin.cu pool3d_op_plugin.cu
deformable_conv_op_plugin.cu deformable_conv_op_plugin.cu
matmul_op_int8_plugin.cu matmul_op_int8_plugin.cu
transformer_input_convert_plugin.cu
remove_padding_plugin.cu
recover_padding_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Scatters a packed (padding-removed) sequence tensor back into a padded
// layout. Launch configuration (set by enqueue):
//   gridDim.x = batch, gridDim.y = max sequence length,
//   gridDim.z * blockDim.x = hidden size.
// input1 is the pos_id tensor: prefix sums of per-batch sequence lengths,
// so input1[b + 1] - input1[b] is the real length of batch b.
__global__ void RecoverPaddingKernel(const float* input0, const int32_t* input1,
                                     float* output) {
  // Flat (batch, position) index of this word in the padded output.
  int word_id = blockIdx.x * gridDim.y + blockIdx.y;
  int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
  if (blockIdx.y < seqence_length) {
    // Real token: copy from its packed offset into the padded slot.
    output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
           threadIdx.x] =
        input0[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x +
               blockIdx.z * blockDim.x + threadIdx.x];
  } else {
    // Position beyond the real sequence length: write zero padding.
    output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
           threadIdx.x] = 0;
  }
}
// The output keeps the data type of the packed data input (input 0);
// the int32 pos_id input does not influence the output type.
nvinfer1::DataType RecoverPaddingPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_types[0];
}
// Padded output shape is (batch, max_seq_len, hidden):
//   d[0] = pos_id dims[0] - 1   (pos_id stores batch + 1 prefix sums),
//   d[1] = mask_id dims[1]      (the padded / maximum sequence length),
//   d[2] = packed input dims[1] (the hidden size).
nvinfer1::DimsExprs RecoverPaddingPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output_dims{};
  output_dims.nbDims = 3;
  const auto* one = exprBuilder.constant(1);
  output_dims.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kSUB,
                                           *inputs[1].d[0], *one);
  output_dims.d[1] = inputs[2].d[1];
  output_dims.d[2] = inputs[0].d[1];
  return output_dims;
}
// Accepted I/O layout: all tensors linear-format; the pos_id input
// (index 1) must be int32, every other input/output must be fp32.
// Inputs: 0 = packed float data, 1 = pos_id, 2 = mask_id.
bool RecoverPaddingPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));
  if (pos == 1) {  // PosId, MaxSeqlen
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
  return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
         inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  // fp16 (kHALF/linear) and int8 (kCHW32) support is sketched below but
  // currently disabled:
  // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
  // == nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
  // nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format ==
  // nvinfer1::TensorFormat::kCHW32);
}
// Nothing to pre-compute from the input/output descriptions.
void RecoverPaddingPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs,
    int nbOutputs) TRT_NOEXCEPT {}
// This plugin does not use the cuDNN/cuBLAS handles or the GPU allocator.
void RecoverPaddingPlugin::attachToContext(
    cudnnContext* cudnnContext, cublasContext* cublasContext,
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
void RecoverPaddingPlugin::detachFromContext() TRT_NOEXCEPT {}
// No device resources are held, so teardown is a no-op.
void RecoverPaddingPlugin::terminate() TRT_NOEXCEPT {}
// Launches RecoverPaddingKernel to expand the packed input (input 0)
// into a zero-padded (batch, max_seq_len, hidden) output, guided by the
// pos_id prefix sums (input 1). Input 2 (mask_id) only supplies the
// padded sequence-length dimension.
int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                  const nvinfer1::PluginTensorDesc* outputDesc,
                                  const void* const* inputs,
                                  void* const* outputs, void* workspace,
                                  cudaStream_t stream) TRT_NOEXCEPT {
  const auto input0_desc = inputDesc[0];
  const auto input1_desc = inputDesc[1];
  const auto input2_desc = inputDesc[2];
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);
  const int32_t num_threads = 256;
  // Grid: (batch, max sequence length (mask_id dims[1]), hidden / 256).
  // NOTE(review): this assumes the hidden size (input0 dims[1]) is an
  // exact multiple of 256; a remainder would be silently dropped — confirm.
  const dim3 num_blocks(
      input1_desc.dims.d[0] - 1, input2_desc.dims.d[1],
      input0_desc.dims.d[1] / num_threads);
  RecoverPaddingKernel<<<num_blocks, num_threads, 0, stream>>>(input0, input1,
                                                               output);
  return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cassert>
#include <string>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Expands a packed variable-length sequence tensor (produced by
// RemovePaddingPlugin) back to a zero-padded (batch, max_seq_len, hidden)
// layout. The plugin is stateless, so (de)serialization carries no payload.
class RecoverPaddingPlugin : public DynamicPluginTensorRT {
 public:
  RecoverPaddingPlugin() {}

  // Stateless: the serialized payload is intentionally ignored.
  RecoverPaddingPlugin(void const* serial_data, size_t serial_length) {}

  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
    RecoverPaddingPlugin* ptr = new RecoverPaddingPlugin();
    return ptr;
  }

  const char* getPluginType() const TRT_NOEXCEPT override {
    return "recover_padding_plugin";
  }

  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }

  // `override` added to both lifecycle virtuals for consistency with the
  // sibling dynamic plugin headers in this directory.
  int initialize() TRT_NOEXCEPT override { return 0; }

  void terminate() TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* outputs,
                       int nbOutputs) TRT_NOEXCEPT override;

  // No scratch memory is needed: the kernel writes the output in place.
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const TRT_NOEXCEPT override {
    return 0;
  }

  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
                       nvinfer1::IGpuAllocator* gpuAllocator)
      TRT_NOEXCEPT override;

  void detachFromContext() TRT_NOEXCEPT override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) TRT_NOEXCEPT override;

  nvinfer1::DataType getOutputDataType(
      int index, const nvinfer1::DataType* inputTypes,
      int nbInputs) const TRT_NOEXCEPT override;

  void destroy() TRT_NOEXCEPT override { delete this; }

 protected:
  // Stateless plugin: nothing to serialize.
  size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }

  void serialize(void* buffer) const TRT_NOEXCEPT override {}
};
// Registers/recreates RecoverPaddingPlugin with TensorRT's plugin registry.
class RecoverPaddingPluginCreator : public nvinfer1::IPluginCreator {
 public:
  RecoverPaddingPluginCreator() {}
  const char* getPluginName() const TRT_NOEXCEPT override {
    return "recover_padding_plugin";
  }
  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
  // No creation-time attributes: the collection is empty.
  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
    return &field_collection_;
  }
  // Creation from attributes is unsupported; the plugin is instantiated
  // directly by the converter and recreated via deserializePlugin.
  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* plugin_field)
      TRT_NOEXCEPT override {
    return nullptr;
  }
  nvinfer1::IPluginV2* deserializePlugin(
      const char* name, void const* serial_data,
      size_t serial_length) TRT_NOEXCEPT override {
    RecoverPaddingPlugin* obj =
        new RecoverPaddingPlugin(serial_data, serial_length);
    // NOTE(review): `name` here is the plugin *name*, not a namespace —
    // confirm this is intentional before relying on getPluginNamespace().
    obj->setPluginNamespace(name);
    return obj;
  }
  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
    plugin_namespace_ = lib_namespace;
  }
  const char* getPluginNamespace() const TRT_NOEXCEPT override {
    return plugin_namespace_.c_str();
  }
 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};
REGISTER_TRT_PLUGIN_V2(RecoverPaddingPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Gathers the real (unpadded) tokens of a padded (batch, max_seq_len,
// hidden) tensor into a packed sum(seq_len) x hidden layout. Launch
// configuration (set by enqueue):
//   gridDim.x = batch, gridDim.y = max sequence length,
//   gridDim.z * blockDim.x = hidden size.
// input1 (pos_id) holds prefix sums of per-batch sequence lengths;
// padding positions are simply skipped (no write).
__global__ void RemovePaddingKernel(const float* input0, const int32_t* input1,
                                    float* output) {
  // Flat (batch, position) index of this word in the padded input.
  int word_id = blockIdx.x * gridDim.y + blockIdx.y;
  int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
  if (blockIdx.y < seqence_length) {
    output[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x +
           blockIdx.z * blockDim.x + threadIdx.x] =
        input0[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
               threadIdx.x];
  }
}
// The output keeps the data type of the padded data input (input 0);
// the int32 pos_id/word_id inputs do not influence the output type.
nvinfer1::DataType RemovePaddingPlugin::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  return input_types[0];
}
// Packed output shape is (total_tokens, hidden, 1, 1):
//   d[0] = word_id dims[0]      (total number of real tokens),
//   d[1] = padded input dims[2] (the hidden size),
//   d[2] = d[3] = 1             (trailing unit dims kept for 4-D consumers).
nvinfer1::DimsExprs RemovePaddingPlugin::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output_dims{};
  output_dims.nbDims = 4;
  output_dims.d[0] = inputs[2].d[0];
  output_dims.d[1] = inputs[0].d[2];
  output_dims.d[2] = exprBuilder.constant(1);
  output_dims.d[3] = exprBuilder.constant(1);
  return output_dims;
}
// Accepted I/O layout: all tensors linear-format; the pos_id (index 1)
// and word_id (index 2) inputs must be int32, everything else fp32.
// Inputs: 0 = padded float data, 1 = pos_id, 2 = word_id.
bool RemovePaddingPlugin::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(nbInputs, 3,
                    platform::errors::InvalidArgument("Must have 3 inputs, "
                                                      "but got %d input(s). ",
                                                      nbInputs));
  PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(),
                    platform::errors::InvalidArgument("Must have 1 output, "
                                                      "but got %d output(s). ",
                                                      nbOutputs));
  if (pos == 1 || pos == 2) {  // pos_id, work_id
    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  }
  return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
         inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
  // fp16 (kHALF/linear) and int8 (kCHW32) support is sketched below but
  // currently disabled:
  // return (inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format
  // == nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format ==
  // nvinfer1::TensorFormat::kLINEAR)||
  // (inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format ==
  // nvinfer1::TensorFormat::kCHW32);
}
// Nothing to pre-compute from the input/output descriptions.
void RemovePaddingPlugin::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* outputs,
    int nbOutputs) TRT_NOEXCEPT {}
// This plugin does not use the cuDNN/cuBLAS handles or the GPU allocator.
void RemovePaddingPlugin::attachToContext(
    cudnnContext* cudnnContext, cublasContext* cublasContext,
    nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
void RemovePaddingPlugin::detachFromContext() TRT_NOEXCEPT {}
// No device resources are held, so teardown is a no-op.
void RemovePaddingPlugin::terminate() TRT_NOEXCEPT {}
// Launches RemovePaddingKernel to gather the real tokens of the padded
// input (input 0) into a packed (sum(seq_len), hidden, 1, 1) output,
// guided by the pos_id prefix sums (input 1).
// Fix: removed the unused local `input_desc`, which duplicated
// `input0_desc` (both copied inputDesc[0]).
int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                                 const nvinfer1::PluginTensorDesc* outputDesc,
                                 const void* const* inputs,
                                 void* const* outputs, void* workspace,
                                 cudaStream_t stream) TRT_NOEXCEPT {
  const float* input0 = static_cast<const float*>(inputs[0]);
  const int32_t* input1 =
      static_cast<const int32_t*>(inputs[1]);  // pos_id_tensor
  float* output = static_cast<float*>(outputs[0]);
  const auto input0_desc = inputDesc[0];
  const int32_t num_threads = 256;
  // Grid: (batch, max sequence length, hidden / 256).
  // NOTE(review): assumes the hidden size (input0 dims[2]) is an exact
  // multiple of 256; a remainder would be silently dropped — confirm.
  const dim3 num_blocks(
      input0_desc.dims.d[0], input0_desc.dims.d[1],
      input0_desc.dims.d[2] / num_threads);
  RemovePaddingKernel<<<num_blocks, num_threads, 0, stream>>>(input0, input1,
                                                              output);
  return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include <stdio.h>
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
#if IS_TRT_VERSION_GE(6000) class RemovePaddingPlugin : public DynamicPluginTensorRT {
class SpecialSlicePluginDynamic : public DynamicPluginTensorRT {
public: public:
SpecialSlicePluginDynamic(); RemovePaddingPlugin() {}
SpecialSlicePluginDynamic(void const* serial_data, size_t serial_length);
~SpecialSlicePluginDynamic(); RemovePaddingPlugin(void const* serial_data, size_t serial_length) {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
RemovePaddingPlugin* ptr = new RemovePaddingPlugin();
return ptr;
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "remove_padding_plugin";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT { return 0; }
void terminate() TRT_NOEXCEPT;
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override; nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut, const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override; int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT override; int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override; int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator)
TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override; cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType( nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int index, const nvinfer1::DataType* inputTypes,
int nbInputs) const TRT_NOEXCEPT override; int nbInputs) const TRT_NOEXCEPT override;
const char* getPluginType() const TRT_NOEXCEPT override; void destroy() TRT_NOEXCEPT override { delete this; }
int getNbOutputs() const TRT_NOEXCEPT override;
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void* buffer) const TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override;
private: protected:
int axis_; size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
int num_stack_;
void serialize(void* buffer) const TRT_NOEXCEPT override {}
}; };
class SpecialSlicePluginDynamicCreator : public nvinfer1::IPluginCreator { class RemovePaddingPluginCreator : public nvinfer1::IPluginCreator {
public: public:
SpecialSlicePluginDynamicCreator(); RemovePaddingPluginCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override; const char* getPluginName() const TRT_NOEXCEPT override {
const char* getPluginVersion() const TRT_NOEXCEPT override; return "remove_padding_plugin";
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override; }
nvinfer1::IPluginV2* createPlugin(const char* name, const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override; const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* plugin_field)
TRT_NOEXCEPT override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin( nvinfer1::IPluginV2* deserializePlugin(
const char* name, const void* serial_data, const char* name, void const* serial_data,
size_t serial_length) TRT_NOEXCEPT override; size_t serial_length) TRT_NOEXCEPT override {
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override; RemovePaddingPlugin* obj =
const char* getPluginNamespace() const TRT_NOEXCEPT override; new RemovePaddingPlugin(serial_data, serial_length);
obj->setPluginNamespace(name);
return obj;
}
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private: private:
std::string plugin_namespace_; std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
}; };
REGISTER_TRT_PLUGIN_V2(SpecialSlicePluginDynamicCreator); REGISTER_TRT_PLUGIN_V2(RemovePaddingPluginCreator);
#endif
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
// SpecialSlicePluginDynamic is stateless: constructors, destructor and
// (de)serialization are all trivial.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {}
// Stateless: the serialized payload is intentionally ignored.
SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data,
                                                     size_t serial_length) {}
SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const
    TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}
const char* SpecialSlicePluginDynamic::getPluginType() const TRT_NOEXCEPT {
  return "special_slice_plugin";
}
int SpecialSlicePluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }
int SpecialSlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
// Nothing to serialize for a stateless plugin.
size_t SpecialSlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
  size_t serialize_size = 0;
  return serialize_size;
}
void SpecialSlicePluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {}
// Derives the output shape from the packed input (input 0) and the
// cu_seqlens tensor (input 1). Starting from the input dims, the code
// shifts every dim right by one, inserts d[1] = 1 and
// d[0] = batch (= cu_seqlens dims[0] - 1), then drops the two trailing
// unit dims. Net effect for a (sum(S), hidden, 1, 1) input: the output
// becomes (batch, 1, hidden) — assumes a 4-D packed input; confirm
// against the converter that builds this plugin.
nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
    nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
  nvinfer1::DimsExprs output(inputs[0]);
  output.nbDims++;
  // Shift existing dims one slot right to make room for the batch dim.
  for (int i = output.nbDims - 1; i > 1; i--) {
    output.d[i] = inputs[0].d[i - 1];
  }
  auto one = expr_builder.constant(1);
  output.d[1] = one;
  // batch = cu_seqlens length - 1 (prefix-sum tensor has batch+1 entries).
  output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                       *inputs[1].d[0], *one);
  // remove padding 1
  output.nbDims -= 2;
  return output;
}
// Nothing to pre-compute from the input/output descriptions.
void SpecialSlicePluginDynamic::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT {}
// No scratch memory needed: the kernel writes the output directly.
size_t SpecialSlicePluginDynamic::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}
void SpecialSlicePluginDynamic::destroy() TRT_NOEXCEPT { delete this; }
// No device resources are held, so teardown is a no-op.
void SpecialSlicePluginDynamic::terminate() TRT_NOEXCEPT {}
// Accepted I/O layout: the sliced data (input 0) and the output must be
// fp16 linear-format; cu_seqlens (input 1) must be int32 linear-format.
// (fp32 support was considered but is disabled — see trailing comments.)
bool SpecialSlicePluginDynamic::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs,
    int nb_outputs) TRT_NOEXCEPT {
  if (pos == 0)  // slice tensor
    return (desc[pos].type == nvinfer1::DataType::kHALF &&
            desc[pos].format ==
                nvinfer1::TensorFormat::kLINEAR);  // || desc[pos].type ==
                                                   // nvinfer1::DataType::kFLOAT);
  if (pos == 1)  // cu_seqlen
    return (desc[pos].type == nvinfer1::DataType::kINT32 &&
            desc[pos].format == nvinfer1::TensorFormat::kLINEAR);
  return (desc[pos].type == nvinfer1::DataType::kHALF &&
          desc[pos].format ==
              nvinfer1::TensorFormat::kLINEAR);  // || desc[pos].type ==
                                                 // nvinfer1::DataType::kFLOAT);
}
// Single output whose type mirrors the sliced data input (input 0).
nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType(
    int index, const nvinfer1::DataType* input_types,
    int nb_inputs) const TRT_NOEXCEPT {
  PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
                                  "The index should be equal to 0"));
  return input_types[0];
}
// For each batch, copies the hidden vector of the token located at the
// sequence's start offset (cu_seqlens[batch]) in the packed input —
// i.e. the first token of each variable-length sequence.
// Launch configuration: grid = (hidden / blockDim.x, batch),
// so hidden = blockDim.x * gridDim.x.
template <typename T>
__global__ void SpecialSliceKernel(const T* slice_input,
                                   const int32_t* cu_seqlens, T* output) {
  const int hidden = blockDim.x * gridDim.x;
  const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x;
  const int batch_id = blockIdx.y;
  output[batch_id * hidden + hidden_id] =
      slice_input[cu_seqlens[batch_id] * hidden + hidden_id];
}
// Launches SpecialSliceKernel to extract, per batch, the hidden vector
// of each sequence's first token from the packed fp16 input.
// Preconditions enforced at runtime: fp16 input, hidden % 128 == 0
// (one 128-thread block covers exactly hidden/128 chunks).
int SpecialSlicePluginDynamic::enqueue(
    const nvinfer1::PluginTensorDesc* input_desc,
    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
    void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT {
  auto input_dims = input_desc[0].dims;  // (sum(S), hidden, 1, 1)
  auto out_dims = output_desc[0].dims;   // (batch, hidden, 1, 1)
  PADDLE_ENFORCE_EQ(
      input_desc[0].type, nvinfer1::DataType::kHALF,
      platform::errors::InvalidArgument("Type of input should be half."));
  const int32_t hidden = input_dims.d[1];
  PADDLE_ENFORCE_EQ(hidden % 128, 0, platform::errors::InvalidArgument(
                                         "hidden should be multiple of 128."));
  constexpr int num_threads = 128;
  const half* slice_input = static_cast<const half*>(inputs[0]);
  const int32_t* cu_seqlens = static_cast<const int32_t*>(inputs[1]);
  half* output = static_cast<half*>(outputs[0]);
  const int32_t num_blocks_x = hidden / num_threads;
  const int32_t num_blocks_y = out_dims.d[0];         // batchs
  const dim3 num_blocks(num_blocks_x, num_blocks_y);  // blocks
  SpecialSliceKernel<<<num_blocks, num_threads, 0, stream>>>(
      slice_input, cu_seqlens, output);
  return cudaGetLastError() != cudaSuccess;
}
// Registers/recreates SpecialSlicePluginDynamic with TensorRT's plugin
// registry. No creation-time attributes are used.
SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {}
const char* SpecialSlicePluginDynamicCreator::getPluginName() const
    TRT_NOEXCEPT {
  return "special_slice_plugin";
}
const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const
    TRT_NOEXCEPT {
  return "1";
}
// Empty field collection: the plugin takes no attributes.
const nvinfer1::PluginFieldCollection*
SpecialSlicePluginDynamicCreator::getFieldNames() TRT_NOEXCEPT {
  return &field_collection_;
}
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT {
  return new SpecialSlicePluginDynamic();
}
nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin(
    const char* name, const void* serial_data,
    size_t serial_length) TRT_NOEXCEPT {
  auto plugin = new SpecialSlicePluginDynamic(serial_data, serial_length);
  return plugin;
}
void SpecialSlicePluginDynamicCreator::setPluginNamespace(
    const char* lib_namespace) TRT_NOEXCEPT {
  plugin_namespace_ = lib_namespace;
}
const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const
    TRT_NOEXCEPT {
  return plugin_namespace_.c_str();
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -65,7 +65,7 @@ TEST(PD_Config, gpu_interface) { ...@@ -65,7 +65,7 @@ TEST(PD_Config, gpu_interface) {
&min_shape_ptr, &max_shape_ptr, &min_shape_ptr, &max_shape_ptr,
&opt_shape_ptr, FALSE); &opt_shape_ptr, FALSE);
PD_ConfigDisableTensorRtOPs(config, 1, &ops_name); PD_ConfigDisableTensorRtOPs(config, 1, &ops_name);
PD_ConfigEnableTensorRtOSS(config); PD_ConfigEnableVarseqlen(config);
bool oss_enabled = PD_ConfigTensorRtOssEnabled(config); bool oss_enabled = PD_ConfigTensorRtOssEnabled(config);
EXPECT_TRUE(oss_enabled); EXPECT_TRUE(oss_enabled);
......
...@@ -210,7 +210,11 @@ std::shared_ptr<paddle_infer::Predictor> InitPredictor() { ...@@ -210,7 +210,11 @@ std::shared_ptr<paddle_infer::Predictor> InitPredictor() {
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape); opt_input_shape);
// erinie varlen must be used with oss // erinie varlen must be used with oss
config.EnableTensorRtOSS(); config.EnableVarseqlen();
paddle_infer::experimental::InternalUtils::SetTransformerPosid(&config,
input_name2);
paddle_infer::experimental::InternalUtils::SetTransformerMaskid(&config,
input_name3);
return paddle_infer::CreatePredictor(config); return paddle_infer::CreatePredictor(config);
} }
......
...@@ -68,7 +68,7 @@ std::shared_ptr<Predictor> InitPredictor() { ...@@ -68,7 +68,7 @@ std::shared_ptr<Predictor> InitPredictor() {
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape); opt_input_shape);
// erinie varlen must be used with oss // erinie varlen must be used with oss
config.EnableTensorRtOSS(); config.EnableVarseqlen();
return CreatePredictor(config); return CreatePredictor(config);
} }
......
...@@ -43,7 +43,7 @@ TEST(table_printer, output) { ...@@ -43,7 +43,7 @@ TEST(table_printer, output) {
table.InsertRow({"trt_precision", "fp32"}); table.InsertRow({"trt_precision", "fp32"});
table.InsertRow({"enable_dynamic_shape", "true"}); table.InsertRow({"enable_dynamic_shape", "true"});
table.InsertRow({"DisableTensorRtOPs", "{}"}); table.InsertRow({"DisableTensorRtOPs", "{}"});
table.InsertRow({"EnableTensorRtOSS", "ON"}); table.InsertRow({"EnableVarseqlen", "ON"});
table.InsertRow({"tensorrt_dla_enabled", "ON"}); table.InsertRow({"tensorrt_dla_enabled", "ON"});
table.InsetDivider(); table.InsetDivider();
......
...@@ -657,8 +657,9 @@ void BindAnalysisConfig(py::module *m) { ...@@ -657,8 +657,9 @@ void BindAnalysisConfig(py::module *m) {
py::arg("disable_trt_plugin_fp16") = false) py::arg("disable_trt_plugin_fp16") = false)
.def("tensorrt_dynamic_shape_enabled", .def("tensorrt_dynamic_shape_enabled",
&AnalysisConfig::tensorrt_dynamic_shape_enabled) &AnalysisConfig::tensorrt_dynamic_shape_enabled)
.def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS) .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
.def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled) .def("tensorrt_varseqlen_enabled",
&AnalysisConfig::tensorrt_varseqlen_enabled)
.def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo) .def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo)
.def("shape_range_info_path", &AnalysisConfig::shape_range_info_path) .def("shape_range_info_path", &AnalysisConfig::shape_range_info_path)
.def("shape_range_info_collected", .def("shape_range_info_collected",
......
...@@ -42,7 +42,7 @@ class InferencePassTest(unittest.TestCase): ...@@ -42,7 +42,7 @@ class InferencePassTest(unittest.TestCase):
self.enable_mkldnn = False self.enable_mkldnn = False
self.enable_mkldnn_bfloat16 = False self.enable_mkldnn_bfloat16 = False
self.enable_trt = False self.enable_trt = False
self.enable_tensorrt_oss = True self.enable_tensorrt_varseqlen = True
self.trt_parameters = None self.trt_parameters = None
self.dynamic_shape_params = None self.dynamic_shape_params = None
self.enable_lite = False self.enable_lite = False
...@@ -134,8 +134,8 @@ class InferencePassTest(unittest.TestCase): ...@@ -134,8 +134,8 @@ class InferencePassTest(unittest.TestCase):
self.dynamic_shape_params.max_input_shape, self.dynamic_shape_params.max_input_shape,
self.dynamic_shape_params.optim_input_shape, self.dynamic_shape_params.optim_input_shape,
self.dynamic_shape_params.disable_trt_plugin_fp16) self.dynamic_shape_params.disable_trt_plugin_fp16)
if self.enable_tensorrt_oss: if self.enable_tensorrt_varseqlen:
config.enable_tensorrt_oss() config.enable_tensorrt_varseqlen()
elif use_mkldnn: elif use_mkldnn:
config.enable_mkldnn() config.enable_mkldnn()
......
...@@ -46,7 +46,7 @@ class QuantDequantTest(unittest.TestCase): ...@@ -46,7 +46,7 @@ class QuantDequantTest(unittest.TestCase):
self.enable_mkldnn = False self.enable_mkldnn = False
self.enable_mkldnn_bfloat16 = False self.enable_mkldnn_bfloat16 = False
self.enable_trt = False self.enable_trt = False
self.enable_tensorrt_oss = True self.enable_tensorrt_varseqlen = True
self.trt_parameters = None self.trt_parameters = None
self.dynamic_shape_params = None self.dynamic_shape_params = None
self.enable_lite = False self.enable_lite = False
...@@ -184,8 +184,8 @@ class QuantDequantTest(unittest.TestCase): ...@@ -184,8 +184,8 @@ class QuantDequantTest(unittest.TestCase):
self.dynamic_shape_params.max_input_shape, self.dynamic_shape_params.max_input_shape,
self.dynamic_shape_params.optim_input_shape, self.dynamic_shape_params.optim_input_shape,
self.dynamic_shape_params.disable_trt_plugin_fp16) self.dynamic_shape_params.disable_trt_plugin_fp16)
if self.enable_tensorrt_oss: if self.enable_tensorrt_varseqlen:
config.enable_tensorrt_oss() config.enable_tensorrt_varseqlen()
elif use_mkldnn: elif use_mkldnn:
config.enable_mkldnn() config.enable_mkldnn()
......
...@@ -179,7 +179,7 @@ def multiclass_nms(bboxes, ...@@ -179,7 +179,7 @@ def multiclass_nms(bboxes,
class TensorRTMultiClassNMS3Test(InferencePassTest): class TensorRTMultiClassNMS3Test(InferencePassTest):
def setUp(self): def setUp(self):
self.enable_trt = True self.enable_trt = True
self.enable_tensorrt_oss = True self.enable_tensorrt_varseqlen = True
self.precision = AnalysisConfig.Precision.Float32 self.precision = AnalysisConfig.Precision.Float32
self.serialize = False self.serialize = False
self.bs = 1 self.bs = 1
...@@ -291,8 +291,8 @@ class TensorRTMultiClassNMS3Test(InferencePassTest): ...@@ -291,8 +291,8 @@ class TensorRTMultiClassNMS3Test(InferencePassTest):
self.background = 7 self.background = 7
self.run_test() self.run_test()
def test_disable_oss(self): def test_disable_varseqlen(self):
self.diable_tensorrt_oss = False self.diable_tensorrt_varseqlen = False
self.run_test() self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册