diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index d000dc70853659d27885721e2d1c1863f49d3067..b430a409e99657929da94ac367a8b68f433f1871 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -107,6 +107,9 @@ if(WITH_TENSORRT)
   pass_library(trt_map_matmul_to_mul_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_skip_layernorm_fuse_pass inference)
+  pass_library(set_transformer_input_convert_pass inference)
+  pass_library(remove_padding_recover_padding_pass inference)
+  pass_library(delete_remove_padding_recover_padding_pass inference)
 endif()
 
 if(WITH_GPU OR WITH_ROCM)
@@ -161,6 +164,7 @@ if(WITH_IPU)
   pass_library(infer_shape_pass base DIR ipu)
   pass_library(delete_scale_op_pass base DIR ipu)
   pass_library(avg_shard_pass base DIR ipu)
+  pass_library(transfer_cast_op_pass base DIR ipu)
 endif()
 
 cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
diff --git a/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..63233e0b584b2216fa58801e6e1919d755801f09
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h"
+
+#include <string>
+
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+void RecoverPadding::operator()() {
+  // Create nodes for recover_padding.
+  auto *recover_padding_input =
+      pattern->NewNode(recover_padding_input_repr())
+          ->assert_is_op_input("recover_padding", "Input");
+  auto *recover_padding_op = pattern->NewNode(recover_padding_op_repr())
+                                 ->assert_is_op("recover_padding");
+  auto *recover_padding_out =
+      pattern->NewNode(recover_padding_out_repr())
+          ->assert_is_op_output("recover_padding", "Out");
+
+  // Add links for recover_padding op.
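+  // Pattern: Input -> recover_padding -> Out. The pass handler deletes each
+  // remove_padding consumer of Out, and drops recover_padding itself only
+  // when every consumer was a remove_padding op.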
+  recover_padding_op->LinksFrom({recover_padding_input})
+      .LinksTo({recover_padding_out});
+}
+}  // namespace patterns
+
+void DeleteRemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init(name_scope_, graph);
+  int found_subgraph_count = 0;
+
+  GraphPatternDetector gpd;
+  patterns::RecoverPadding recover_padding(
+      gpd.mutable_pattern(), "delete_remove_padding_recover_padding_pass");
+  recover_padding();
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *graph) {
+    VLOG(3) << "delete_remove_padding_recover_padding_pass";
+
+    GET_IR_NODE_FROM_SUBGRAPH(recover_padding_input, recover_padding_input,
+                              recover_padding);
+    GET_IR_NODE_FROM_SUBGRAPH(recover_padding_op, recover_padding_op,
+                              recover_padding);
+    GET_IR_NODE_FROM_SUBGRAPH(recover_padding_out, recover_padding_out,
+                              recover_padding);
+
+    std::unordered_set<const Node *> del_node_set;
+
+    bool delete_recover_padding = true;
+    for (size_t i = 0; i < recover_padding_out->outputs.size(); ++i) {
+      if (recover_padding_out->outputs[i]->Name() ==
+          "remove_padding") {  // op_node
+        auto *remove_padding_out_node =
+            recover_padding_out->outputs[i]->outputs[0];  // var_node
+        auto *out_op_node = remove_padding_out_node->outputs[0];  // op_node
+        IR_NODE_LINK_TO(recover_padding_input, out_op_node);
+        del_node_set.insert(recover_padding_out->outputs[i]);
+        del_node_set.insert(remove_padding_out_node);
+        out_op_node->Op()->RenameInput(remove_padding_out_node->Name(),
+                                       recover_padding_input->Name());
+        found_subgraph_count++;
+      } else {
+        delete_recover_padding = false;
+      }
+    }
+    if (delete_recover_padding) {
+      del_node_set.insert(recover_padding_op);
+      del_node_set.insert(recover_padding_out);
+    }
+    GraphSafeRemoveNodes(graph, del_node_set);
+  };
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_remove_padding_recover_padding_pass,
+              paddle::framework::ir::DeleteRemovePaddingRecoverPaddingPass);
diff --git a/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h b/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..3504b124c91d12d0cbd1165b42d487e595bd90b3
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
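+
+// delete_remove_padding_recover_padding_pass removes the redundant
+// (recover_padding -> remove_padding) pairs left between two ops that were
+// both wrapped by remove_padding_recover_padding_pass, so tensors stay in
+// their compact, padding-free layout across consecutive wrapped ops.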
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+class Graph;
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct RecoverPadding : public PatternBase {
+  RecoverPadding(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "recover_padding") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(recover_padding_input);
+  PATTERN_DECL_NODE(recover_padding_op);
+  PATTERN_DECL_NODE(recover_padding_out);
+};
+}  // namespace patterns
+
+class DeleteRemovePaddingRecoverPaddingPass : public FusePassBase {
+ public:
+  DeleteRemovePaddingRecoverPaddingPass() {}
+  virtual ~DeleteRemovePaddingRecoverPaddingPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const;
+  const std::string name_scope_{"delete_remove_padding_recover_padding_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..67dfe074dc075d7517ff46ea227684bc8a8fd441
--- /dev/null
+++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc
@@ -0,0 +1,298 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h"
+
+#include <string>
+
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+void SkipLayernorm::operator()() {
+  // Create nodes for skip_layernorm.
+  auto* skip_layernorm_x = pattern->NewNode(skip_layernorm_x_repr())
+                               ->assert_is_op_input("skip_layernorm", "X");
+  auto* skip_layernorm_y = pattern->NewNode(skip_layernorm_y_repr())
+                               ->assert_is_op_input("skip_layernorm", "Y");
+  auto* skip_layernorm_op = pattern->NewNode(skip_layernorm_op_repr())
+                                ->assert_is_op("skip_layernorm");
+  auto* skip_layernorm_out = pattern->NewNode(skip_layernorm_out_repr())
+                                 ->assert_is_op_output("skip_layernorm", "Out");
+
+  // Add links for skip_layernorm op.
+  skip_layernorm_op->LinksFrom({skip_layernorm_x, skip_layernorm_y})
+      .LinksTo({skip_layernorm_out});
+}
+
+void MultiheadMatmul::operator()() {
+  // Create nodes for multihead_matmul.
+  auto* multihead_matmul_input =
+      pattern->NewNode(multihead_matmul_input_repr())
+          ->assert_is_op_input("multihead_matmul", "Input");
+  auto* multihead_matmul_op = pattern->NewNode(multihead_matmul_op_repr())
+                                  ->assert_is_op("multihead_matmul");
+  auto* multihead_matmul_out =
+      pattern->NewNode(multihead_matmul_out_repr())
+          ->assert_is_op_output("multihead_matmul", "Out");
+
+  // Add links for multihead_matmul op.
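+  // Only the activation path (Input -> Out) is matched; the op's weight and
+  // bias inputs are persistable parameters and carry no padding.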
+  multihead_matmul_op->LinksFrom({multihead_matmul_input})
+      .LinksTo({multihead_matmul_out});
+}
+
+void Fc::operator()() {
+  // Create nodes for fc.
+  auto* fc_input =
+      pattern->NewNode(fc_input_repr())->assert_is_op_input("fc", "Input");
+  auto* fc_op = pattern->NewNode(fc_op_repr())->assert_is_op("fc");
+  auto* fc_out =
+      pattern->NewNode(fc_out_repr())->assert_is_op_output("fc", "Out");
+
+  // Add links for fc op.
+  fc_op->LinksFrom({fc_input}).LinksTo({fc_out});
+}
+
+void Activation::operator()() {
+  // Create nodes for activation.
+  std::unordered_set<std::string> activation_ops{"relu", "sigmoid", "tanh"};
+  auto* activation_input = pattern->NewNode(activation_input_repr())
+                               ->assert_is_ops_input(activation_ops);
+  auto* activation_op =
+      pattern->NewNode(activation_op_repr())->assert_is_ops(activation_ops);
+  auto* activation_out = pattern->NewNode(activation_out_repr())
+                             ->assert_is_ops_output(activation_ops);
+
+  // Add links for activation op.
+  activation_op->LinksFrom({activation_input}).LinksTo({activation_out});
+}
+}  // namespace patterns
+
+void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init(name_scope_, graph);
+  auto* scope = param_scope();
+  int found_subgraph_count = 0;
+
+  // Create a remove_padding op node.
+  auto insert_remove_padding_op = [&](Node* input_node, Node* op_node) {
+    // create op, var in graph
+    OpDesc remove_padding;
+    std::string remove_padding_out_name =
+        input_node->Name() + ".remove_padding";
+
+    VarDesc remove_padding_out(remove_padding_out_name);
+    remove_padding_out.SetDataType(input_node->Var()->GetDataType());
+    remove_padding_out.SetShape(input_node->Var()->GetShape());
+    remove_padding_out.SetPersistable(false);
+
+    // remove_padding_op
+    remove_padding.SetType("remove_padding");
+
+    // input
+    remove_padding.SetInput("Input", {input_node->Name()});
+
+    // output
+    remove_padding.SetOutput("Out", {remove_padding_out_name});
+
+    auto remove_padding_op_node = graph->CreateOpNode(&remove_padding);
+    auto remove_padding_out_node = graph->CreateVarNode(&remove_padding_out);
+
+    // replace link
+    for (size_t i = 0; i < input_node->outputs.size(); ++i) {
+      if (input_node->outputs[i] == op_node) {
+        input_node->outputs[i] = remove_padding_op_node;
+        remove_padding_op_node->inputs.push_back(input_node);
+      }
+    }
+
+    // link node
+    IR_NODE_LINK_TO(remove_padding_op_node, remove_padding_out_node);
+
+    // replace link
+    for (size_t i = 0; i < op_node->inputs.size(); ++i) {
+      if (op_node->inputs[i] == input_node) {
+        op_node->inputs[i] = remove_padding_out_node;
+        remove_padding_out_node->outputs.push_back(op_node);
+      }
+    }
+
+    // create variable in scope
+    scope->Var(remove_padding_out_name);
+    auto* remove_padding_out_tensor =
+        scope->FindVar(remove_padding_out_name)->GetMutable<LoDTensor>();
+    remove_padding_out_tensor->mutable_data<float>(platform::CUDAPlace());
+
+    // rename
+    op_node->Op()->RenameInput(input_node->Name(),
+                               remove_padding_out_node->Name());
+  };
+
+  // Create a recover_padding op node.
+  auto insert_recover_padding_op = [&](Node* op_node, Node* out_node) {
+    // create op, var in graph
+    OpDesc recover_padding;
+    std::string recover_padding_input_name =
+        out_node->Name() + ".recover_padding";
+    VarDesc recover_padding_input(recover_padding_input_name);
+    recover_padding_input.SetDataType(out_node->Var()->GetDataType());
+    recover_padding_input.SetShape(out_node->Var()->GetShape());
+    recover_padding_input.SetPersistable(false);
+
+    // recover_padding_op
+    recover_padding.SetType("recover_padding");
+
+    // input
+    recover_padding.SetInput("Input", {recover_padding_input_name});
+
+    // output
+    recover_padding.SetOutput("Out", {out_node->Name()});
+
+    auto recover_padding_op_node = graph->CreateOpNode(&recover_padding);
+    auto recover_padding_input_node =
+        graph->CreateVarNode(&recover_padding_input);
+
+    // replace link
+    for (size_t i = 0; i < op_node->outputs.size(); ++i) {
+      if (op_node->outputs[i] == out_node) {
+        op_node->outputs[i] = recover_padding_input_node;
+        recover_padding_input_node->inputs.push_back(op_node);
+      }
+    }
+
+    // link node
+    IR_NODE_LINK_TO(recover_padding_input_node, recover_padding_op_node);
+
+    // replace link
+    for (size_t i = 0; i < out_node->inputs.size(); ++i) {
+      if (out_node->inputs[i] == op_node) {
+        out_node->inputs[i] = recover_padding_op_node;
+        recover_padding_op_node->outputs.push_back(out_node);
+      }
+    }
+
+    // create variable in scope
+    scope->Var(recover_padding_input_name);
+    auto* recover_padding_input_tensor =
+        scope->FindVar(recover_padding_input_name)->GetMutable<LoDTensor>();
+    recover_padding_input_tensor->mutable_data<float>(platform::CUDAPlace());
+
+    // rename
+    op_node->Op()->RenameOutput(out_node->Name(), recover_padding_input_name);
+  };
+
+  GraphPatternDetector gpd1;
+  patterns::SkipLayernorm skip_layernorm(gpd1.mutable_pattern(),
+                                         "remove_padding_recover_padding_pass");
+  skip_layernorm();
+
+  auto handler1 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "skip_layernorm";
+
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_x, skip_layernorm_x,
+                              skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_y, skip_layernorm_y,
+                              skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_op, skip_layernorm_op,
+                              skip_layernorm);
+    GET_IR_NODE_FROM_SUBGRAPH(skip_layernorm_out, skip_layernorm_out,
+                              skip_layernorm);
+
+    insert_remove_padding_op(skip_layernorm_x, skip_layernorm_op);
+    insert_remove_padding_op(skip_layernorm_y, skip_layernorm_op);
+    insert_recover_padding_op(skip_layernorm_op, skip_layernorm_out);
+
+    found_subgraph_count++;
+  };
+  gpd1(graph, handler1);
+
+  GraphPatternDetector gpd2;
+  patterns::MultiheadMatmul multihead_matmul(
+      gpd2.mutable_pattern(), "remove_padding_recover_padding_pass");
+  multihead_matmul();
+
+  auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: "
+               "multihead_matmul";
+
+    GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input,
+                              multihead_matmul);
+    GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_op, multihead_matmul_op,
+                              multihead_matmul);
+    GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
+                              multihead_matmul);
+
+    insert_remove_padding_op(multihead_matmul_input, multihead_matmul_op);
+    insert_recover_padding_op(multihead_matmul_op, multihead_matmul_out);
+
+    found_subgraph_count++;
+  };
+  gpd2(graph, handler2);
+
+  GraphPatternDetector gpd3;
+  patterns::Fc fc(gpd3.mutable_pattern(),
+                  "remove_padding_recover_padding_pass");
+  fc();
+
+  auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3) << "remove_padding_recover_padding_pass for transformer: fc";
+
+    GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, fc);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc_op, fc);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, fc);
+
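+    // Wrap fc like the ops above: compact the input, re-pad the output.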
+    insert_remove_padding_op(fc_input, fc_op);
+    insert_recover_padding_op(fc_op, fc_out);
+
+    found_subgraph_count++;
+  };
+  gpd3(graph, handler3);
+
+  GraphPatternDetector gpd4;
+  patterns::Activation activation(gpd4.mutable_pattern(),
+                                  "remove_padding_recover_padding_pass");
+  activation();
+
+  auto handler4 = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                      Graph* graph) {
+    VLOG(3)
+        << "remove_padding_recover_padding_pass for transformer: activation";
+
+    GET_IR_NODE_FROM_SUBGRAPH(activation_input, activation_input, activation);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation_op, activation);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, activation);
+
+    insert_remove_padding_op(activation_input, activation_op);
+    insert_recover_padding_op(activation_op, activation_out);
+
+    found_subgraph_count++;
+  };
+  gpd4(graph, handler4);
+
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(remove_padding_recover_padding_pass,
+              paddle::framework::ir::RemovePaddingRecoverPaddingPass);
diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..d7ccfc75c2000824efc378444267c77174259925
--- /dev/null
+++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
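+
+// For variable-length transformer inference in TensorRT, this pass wraps each
+// supported op (skip_layernorm, multihead_matmul, fc, activations) with a
+// remove_padding op on its inputs and a recover_padding op on its output, so
+// that the wrapped op computes on compact, padding-free tensors.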
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+class Graph;
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct SkipLayernorm : public PatternBase {
+  SkipLayernorm(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "skip_layernorm") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(skip_layernorm_x);
+  PATTERN_DECL_NODE(skip_layernorm_y);
+  PATTERN_DECL_NODE(skip_layernorm_op);
+  PATTERN_DECL_NODE(skip_layernorm_out);
+};
+
+struct MultiheadMatmul : public PatternBase {
+  MultiheadMatmul(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "multihead_matmul") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(multihead_matmul_input);
+  PATTERN_DECL_NODE(multihead_matmul_op);
+  PATTERN_DECL_NODE(multihead_matmul_out);
+};
+
+struct Fc : public PatternBase {
+  Fc(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "fc") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(fc_input);
+  PATTERN_DECL_NODE(fc_op);
+  PATTERN_DECL_NODE(fc_out);
+};
+
+struct Activation : public PatternBase {
+  Activation(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "activation") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(activation_input);
+  PATTERN_DECL_NODE(activation_op);
+  PATTERN_DECL_NODE(activation_out);
+};
+}  // namespace patterns
+
+class RemovePaddingRecoverPaddingPass : public FusePassBase {
+ public:
+  RemovePaddingRecoverPaddingPass() {}
+  virtual ~RemovePaddingRecoverPaddingPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const;
+  const std::string name_scope_{"remove_padding_recover_padding_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..37e77bc134d3c645a108da6b063556c32bf48960
--- /dev/null
+++ b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.cc
@@ -0,0 +1,161 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
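+
+// Detects the two embedding lookup_tables feeding the first elementwise_add
+// of a transformer and inserts a transformer_input_convert op that derives
+// the pos_id and max_seqlen tensors needed by variable-length TensorRT
+// plugins.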
+ +#include "paddle/fluid/framework/ir/set_transformer_input_convert_pass.h" + +#include + +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +SetTransformerInputConvertPass::SetTransformerInputConvertPass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); +} +namespace patterns { + +void SetTransformerInputConvert::operator()() { + std::unordered_set lookup_table_ops{"lookup_table", + "lookup_table_v2"}; + // Create nodes for lookup_table1 op. + auto *lookup_table1_x = pattern->NewNode(lookup_table1_x_repr()) + ->assert_is_ops_input(lookup_table_ops, "Ids"); + auto *lookup_table1_w = pattern->NewNode(lookup_table1_w_repr()) + ->assert_is_ops_input(lookup_table_ops, "W"); + auto *lookup_table1_op = + pattern->NewNode(lookup_table1_repr())->assert_is_ops(lookup_table_ops); + auto *lookup_table1_out = pattern->NewNode(lookup_table1_out_repr()) + ->assert_is_ops_output(lookup_table_ops) + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "X"); + + // Create nodes for lookup_table2 op. + auto *lookup_table2_x = pattern->NewNode(lookup_table2_x_repr()) + ->assert_is_ops_input(lookup_table_ops, "Ids"); + auto *lookup_table2_w = pattern->NewNode(lookup_table2_w_repr()) + ->assert_is_ops_input(lookup_table_ops, "W"); + auto *lookup_table2_op = + pattern->NewNode(lookup_table2_repr())->assert_is_ops(lookup_table_ops); + auto *lookup_table2_out = pattern->NewNode(lookup_table2_out_repr()) + ->assert_is_ops_output(lookup_table_ops) + ->AsIntermediate() + ->assert_is_op_input("elementwise_add", "Y"); + + // Create nodes for elementwise_add op. + auto *elementwise_op = + pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); + auto *elementwise_out = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_only_output_of_op("elementwise_add"); + + // links nodes. 
+  lookup_table1_op->LinksFrom({lookup_table1_x, lookup_table1_w})
+      .LinksTo({lookup_table1_out});
+  lookup_table2_op->LinksFrom({lookup_table2_x, lookup_table2_w})
+      .LinksTo({lookup_table2_out});
+  elementwise_op->LinksFrom({lookup_table1_out, lookup_table2_out})
+      .LinksTo({elementwise_out});
+}
+
+}  // namespace patterns
+
+void SetTransformerInputConvertPass::ApplyImpl(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init(name_scope_, graph);
+  int found_subgraph_count = 0;
+
+  GraphPatternDetector gpd;
+  patterns::SetTransformerInputConvert fused_pattern(
+      gpd.mutable_pattern(), "transformer_input_convert_pass");
+  fused_pattern();
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *graph) {
+    if (!IsCompat(subgraph, graph)) {
+      LOG(WARNING) << "transformer_input_convert_pass in op compat failed.";
+      return;
+    }
+
+    VLOG(3) << "transformer_input_convert_pass for pos_id, max_seqlen";
+
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, fused_pattern);
+
+    // create op, var in graph
+    OpDesc new_desc;
+    new_desc.SetType("transformer_input_convert");
+
+    // inputs
+    new_desc.SetInput("X", {lookup_table2_x->Name()});
+
+    // outputs
+    std::vector<std::string> output_0 = {"pos_id_tensor"};
+    std::vector<std::string> output_1 = {"max_seqlen_tensor"};
+    new_desc.SetOutput("PosId", output_0);
+    new_desc.SetOutput("MaxSeqlen", output_1);
+
+    std::string transformer_input_convert_out0_name = "pos_id_tensor";
+    std::string transformer_input_convert_out1_name = "max_seqlen_tensor";
+    VarDesc transformer_input_convert_out0(transformer_input_convert_out0_name);
+    VarDesc transformer_input_convert_out1(transformer_input_convert_out1_name);
+    transformer_input_convert_out0.SetDataType(proto::VarType::INT32);
+    transformer_input_convert_out1.SetDataType(proto::VarType::INT32);
+    transformer_input_convert_out0.SetShape({-1});
+    transformer_input_convert_out1.SetShape({-1});
+    transformer_input_convert_out0.SetPersistable(false);
+    transformer_input_convert_out1.SetPersistable(false);
+
+    auto new_op_node = graph->CreateOpNode(&new_desc);
+    auto transformer_input_convert_out0_node =
+        graph->CreateVarNode(&transformer_input_convert_out0);
+    auto transformer_input_convert_out1_node =
+        graph->CreateVarNode(&transformer_input_convert_out1);
+
+    // No need to create these variables in scope.
+
+    IR_NODE_LINK_TO(lookup_table2_x, new_op_node);
+    IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out0_node);
+    IR_NODE_LINK_TO(new_op_node, transformer_input_convert_out1_node);
+
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(set_transformer_input_convert_pass,
+              paddle::framework::ir::SetTransformerInputConvertPass);
+REGISTER_PASS_CAPABILITY(set_transformer_input_convert_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .LE("lookup_table", 1)
+            .LE("lookup_table_v2", 1)
+            .LE("elementwise_add", 1));
diff --git a/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a5843e810f9a0fc2c0708145f04b9aea286bc18
--- /dev/null
+++ b/paddle/fluid/framework/ir/set_transformer_input_convert_pass.h
@@ -0,0 +1,80 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+class Graph;
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+// in_var  emb       in_var   emb
+//   |      |          |       |
+//  lookup_table      lookup_table
+//        |                 |
+//     lkt_var           lkt_var
+//        \                 /
+//         elementwise_add
+//                |
+//           elt_out_var
+//
+struct SetTransformerInputConvert : public PatternBase {
+  SetTransformerInputConvert(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "transformer_input_convert") {}
+
+  void operator()();
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(lookup_table1);
+  PATTERN_DECL_NODE(lookup_table2);
+  PATTERN_DECL_NODE(elementwise);
+
+  // declare variable node's name
+  PATTERN_DECL_NODE(lookup_table1_x);
+  PATTERN_DECL_NODE(lookup_table1_w);
+  PATTERN_DECL_NODE(lookup_table1_out);
+  PATTERN_DECL_NODE(lookup_table2_x);
+  PATTERN_DECL_NODE(lookup_table2_w);
+  PATTERN_DECL_NODE(lookup_table2_out);
+  PATTERN_DECL_NODE(elementwise_out);
+};
+}  // namespace patterns
+
+class SetTransformerInputConvertPass : public FusePassBase {
+ public:
+  SetTransformerInputConvertPass();
+  virtual ~SetTransformerInputConvertPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const;
+  const std::string name_scope_{"transformer_input_convert_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index e4fc52b6fa74427b1f24b194dffea6f39e2b4692..059a9cb21e1d5e0b6925e85e41c963b91292ec53 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -377,12 +377,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
   trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
   trt_engine->SetDLACore(Get<int>("trt_dla_core"));
   trt_engine->SetUseInspector(Get<bool>("use_inspector"));
-
-  trt_engine->SetWithErnie(
-      (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
-       graph->Has(framework::ir::kMultiheadMatmulPass)) ||
-      (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
-       graph->Has(framework::ir::kMultiheadMatmulPass)));
+  trt_engine->SetWithErnie(graph->Has(framework::ir::kMultiheadMatmulPass));
 
   if (use_static_engine) {
     trt_engine_serialized_data = GetTrtEngineSerializedData(
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index f59494628ad7e5327523b54022323327f161b773..4c3587e54036bcb0b33dfa7bf247cf1f92253239 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -98,18 +98,21 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "multihead_matmul_fuse_pass_v3",            //
"skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", // - "conv_bn_fuse_pass", // - "unsqueeze2_eltwise_fuse_pass", // - "trt_squeeze2_matmul_fuse_pass", // - "trt_reshape2_matmul_fuse_pass", // - "trt_flatten2_matmul_fuse_pass", // - "trt_map_matmul_v2_to_mul_pass", // - "trt_map_matmul_v2_to_matmul_pass", // - "trt_map_matmul_to_mul_pass", // - "fc_fuse_pass", // - "conv_elementwise_add_fuse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + // "set_transformer_input_convert_pass", // + "conv_bn_fuse_pass", // + "unsqueeze2_eltwise_fuse_pass", // + "trt_squeeze2_matmul_fuse_pass", // + "trt_reshape2_matmul_fuse_pass", // + "trt_flatten2_matmul_fuse_pass", // + "trt_map_matmul_v2_to_mul_pass", // + "trt_map_matmul_v2_to_matmul_pass", // + "trt_map_matmul_to_mul_pass", // + "fc_fuse_pass", // + "conv_elementwise_add_fuse_pass", // + // "remove_padding_recover_padding_pass", // + // "delete_remove_padding_recover_padding_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we