diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 5c9841aef1707bfd929e5313c1082aef3aacd87d..3a2ae0ff21788d004a28946b0c08e49bb18bc021 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -174,6 +174,7 @@ if(WITH_TENSORRT) pass_library(set_transformer_input_convert_pass inference) pass_library(remove_padding_recover_padding_pass inference) pass_library(delete_remove_padding_recover_padding_pass inference) + pass_library(layernorm_shift_partition_fuse_pass inference) endif() if(WITH_GPU OR WITH_ROCM) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index d8ec95ee85c98e2dc7ea03fcb1e8eaa832b4fd0e..0d63ce21211318ac208352c2efa908c8b608e067 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3502,6 +3502,106 @@ PDNode *patterns::AddSupportInt8::operator()() { return quant_out; } +PDNode *patterns::LayernormShiftPartitionPattern::operator()() { + auto layer_norm_op = + pattern->NewNode(layer_norm_op_repr()) + ->assert_is_op("layer_norm") + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("begin_norm_axis") && + (PADDLE_GET_CONST( + int, node->Op()->GetAttr("begin_norm_axis")) == 2); + }); + auto layer_norm_in = pattern->NewNode(layer_norm_in_repr()) + ->AsInput() + ->assert_is_op_input("layer_norm", "X"); + auto layer_norm_bias = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_op_input("layer_norm", "Bias"); + auto layer_norm_scale = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_op_input("layer_norm", "Scale"); + auto layer_norm_out = pattern->NewNode(layer_norm_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("reshape2", "X") + ->assert_is_op_output("layer_norm", "Y"); + auto reshape1_op = + pattern->NewNode(reshape1_op_repr()) + ->assert_is_op("reshape2") + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("shape") && + (PADDLE_GET_CONST(std::vector, + node->Op()->GetAttr("shape")) + .size() == 4); + }); + auto reshape1_out = pattern->NewNode(reshape1_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("reshape2", "X") + ->assert_is_op_output("reshape2", "Out"); + auto reshape2_op = + pattern->NewNode(reshape2_op_repr()) + ->assert_is_op("reshape2") + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("shape") && + (PADDLE_GET_CONST(std::vector, + node->Op()->GetAttr("shape")) + .size() == 6); + }); + auto reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("transpose2", "X") + ->assert_is_op_output("reshape2", "Out"); + auto transpose_op = + pattern->NewNode(transpose_op_repr()) + ->assert_is_op("transpose2") + ->assert_more([&](Node *node) { + if (!node->Op()->HasAttr("axis")) return false; + std::vector axis = + PADDLE_GET_CONST(std::vector, node->Op()->GetAttr("axis")); + if (axis.size() != 6) return false; + const std::vector axis_cmp{0, 1, 3, 2, 4, 5}; + return std::equal(axis.begin(), axis.end(), axis_cmp.begin()); + }); + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("reshape2", "X") + ->assert_is_op_output("transpose2", "Out"); + auto reshape3_op = + pattern->NewNode(reshape3_op_repr()) + ->assert_is_op("reshape2") + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("shape") && + (PADDLE_GET_CONST(std::vector, + node->Op()->GetAttr("shape")) + 
.size() == 4); + }); + auto reshape3_out = pattern->NewNode(reshape3_out_repr()) + ->AsIntermediate() + ->assert_is_op_input("reshape2", "X") + ->assert_is_op_output("reshape2", "Out"); + auto reshape4_op = + pattern->NewNode(reshape4_op_repr()) + ->assert_is_op("reshape2") + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("shape") && + (PADDLE_GET_CONST(std::vector, + node->Op()->GetAttr("shape")) + .size() == 3); + }); + auto reshape4_out = pattern->NewNode(reshape4_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->AsOutput(); + + layer_norm_op->LinksFrom({layer_norm_in, layer_norm_bias, layer_norm_scale}) + .LinksTo({layer_norm_out}); + reshape1_op->LinksFrom({layer_norm_out}).LinksTo({reshape1_out}); + reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out}); + transpose_op->LinksFrom({reshape2_out}).LinksTo({transpose_out}); + reshape3_op->LinksFrom({transpose_out}).LinksTo({reshape3_out}); + reshape4_op->LinksFrom({reshape3_out}).LinksTo({reshape4_out}); + + return reshape4_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f97659038262c5741fc9ec34bf5ba19ed9f1e7aa..b2eb740b9acaf77f0e68196ee3f30b15c1dc2623 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1911,6 +1911,34 @@ struct LayerNorm : public PatternBase { PATTERN_DECL_NODE(shift_out); }; +// +// \brief Pattern looking for subgraph representing layernorm_shift_partition +// operation. +// +struct LayernormShiftPartitionPattern : public PatternBase { + LayernormShiftPartitionPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "layernorm_shift_partition") {} + + PDNode* operator()(); + + PATTERN_DECL_NODE(layer_norm_in); + PATTERN_DECL_NODE(layer_norm_op); + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); + PATTERN_DECL_NODE(reshape1_op); + PATTERN_DECL_NODE(reshape1_out); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(reshape3_op); + PATTERN_DECL_NODE(reshape3_out); + PATTERN_DECL_NODE(reshape4_op); + PATTERN_DECL_NODE(reshape4_out); +}; + // Add support int8 flag struct AddSupportInt8 : public PatternBase { AddSupportInt8(PDPattern* pattern, const std::string& name_scope) diff --git a/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..9353f4b3efd848c46fbea4458b0379ca32d92d55 --- /dev/null +++ b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h" + +#include +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Node; + +LayerNormShiftPartitionFusePass::LayerNormShiftPartitionFusePass() { + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Variance") + .IsTensor() + .IsOptional() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumEQ(2) + .End(); + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); +} + +void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, + platform::errors::InvalidArgument( + "The input graph of LayerNormShiftPartitionFusePass should not be " + "nullptr.")); + + FusePassBase::Init(scope_name_, graph); + + GraphPatternDetector gpd; + patterns::LayernormShiftPartitionPattern shift_patition_pattern( + gpd.mutable_pattern(), scope_name_); + shift_patition_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "layernorm_shift_partition_fuse in op compat failed."; + return; + } + + VLOG(4) << "layernorm_shift_partition_fuse pass"; + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_in, layer_norm_in, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_op, layer_norm_op, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_bias, layer_norm_bias, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_out, layer_norm_out, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape1_op, reshape1_op, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape1_out, reshape1_out, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape2_out, reshape2_out, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose_op, transpose_op, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + transpose_out, transpose_out, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape3_op, reshape3_op, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape3_out, reshape3_out, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape4_op, reshape4_op, shift_patition_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + reshape4_out, reshape4_out, shift_patition_pattern); + + std::vector shape_atr1 = + PADDLE_GET_CONST(std::vector, reshape1_op->Op()->GetAttr("shape")); + std::vector shape_atr2 = + PADDLE_GET_CONST(std::vector, 
reshape2_op->Op()->GetAttr("shape")); + std::vector shape_atr3 = + PADDLE_GET_CONST(std::vector, reshape3_op->Op()->GetAttr("shape")); + std::vector shape_atr4 = + PADDLE_GET_CONST(std::vector, reshape4_op->Op()->GetAttr("shape")); + + // emb dim should be same + if (!((shape_atr1.back() == shape_atr2.back()) && + (shape_atr2.back() == shape_atr3.back()) && + (shape_atr3.back() == shape_atr4.back()))) { + return; + } + + if (shape_atr1[1] != shape_atr1[2]) { + return; + } + int input_resolution = shape_atr1[1]; + + if (shape_atr3[1] != shape_atr3[2]) { + return; + } + int window_size = shape_atr2[2]; + if (window_size < 0 || input_resolution < 0) { + return; + } + + OpDesc new_op_desc; + new_op_desc.SetType("layernorm_shift_partition"); + new_op_desc.SetInput("X", {layer_norm_in->Name()}); + new_op_desc.SetInput("Bias", {layer_norm_bias->Name()}); + new_op_desc.SetInput("Scale", {layer_norm_scale->Name()}); + new_op_desc.SetOutput("Y", {reshape4_out->Name()}); + new_op_desc.SetAttr("epsilon", layer_norm_op->Op()->GetAttr("epsilon")); + new_op_desc.SetAttr("begin_norm_axis", + layer_norm_op->Op()->GetAttr("begin_norm_axis")); + new_op_desc.SetAttr("window_size", window_size); + new_op_desc.SetAttr("input_resolution", input_resolution); + new_op_desc.Flush(); + + auto* layernorm_shift_partition = graph->CreateOpNode(&new_op_desc); + + IR_NODE_LINK_TO(layer_norm_in, layernorm_shift_partition); + IR_NODE_LINK_TO(layer_norm_bias, layernorm_shift_partition); + IR_NODE_LINK_TO(layer_norm_scale, layernorm_shift_partition); + IR_NODE_LINK_TO(layernorm_shift_partition, reshape4_out); + GraphSafeRemoveNodes(graph, + {layer_norm_op, + layer_norm_out, + reshape1_op, + reshape1_out, + reshape2_op, + reshape2_out, + transpose_op, + transpose_out, + reshape3_op, + reshape3_out, + reshape4_op}); + ++found_count; + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(layernorm_shift_partition_fuse_pass, + paddle::framework::ir::LayerNormShiftPartitionFusePass); +REGISTER_PASS_CAPABILITY(layernorm_shift_partition_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("transpose2", 0) + .EQ("reshape2", 0)); diff --git a/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..7c3d435ef430447b3cf44709af7657e08684abbf --- /dev/null +++ b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h @@ -0,0 +1,54 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +// | +// layer_norm +// | +// reshape2 +// | +// reshape2 | +// | fuse layernorm_shift_patition +// transpose2 -> | +// | other_op +// reshape2 +// | +// reshape2 +// | +// other_op +class LayerNormShiftPartitionFusePass : public FusePassBase { + public: + LayerNormShiftPartitionFusePass(); + virtual ~LayerNormShiftPartitionFusePass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + + private: + const std::string scope_name_{"layernorm_shift_partition_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 17f1396ce8d6a84658b704e4f21f4c6ae50db0d6..fbc2830aff6148ef7ccc011e51b4c4a8dab13813 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2184,6 +2184,7 @@ USE_TRT_CONVERTER(sum) USE_TRT_CONVERTER(shape) USE_TRT_CONVERTER(fill_constant) USE_TRT_CONVERTER(fused_token_prune) +USE_TRT_CONVERTER(layernorm_shift_partition) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 0b7982950f43d314ddc105d48d62c5ad555b19b7..2b5cb6dd050a6ee0a0410b44d3a5bbd3c40cd469 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -105,6 +105,7 @@ const std::vector kTRTSubgraphPasses({ "trt_skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", // "preln_residual_bias_fuse_pass", // + "layernorm_shift_partition_fuse_pass", // // "set_transformer_input_convert_pass", // "conv_bn_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 1bb7d3c5b1f5ab6033a2636cefc85ead82ece920..ce95363b72d0b3d81cfa9bef464b8160fc7dd488 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -75,7 +75,8 @@ list( sum_op.cc shape_op.cc fill_constant_op.cc - fused_token_prune_op.cc) + fused_token_prune_op.cc + layernorm_shift_partition_op.cc) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) diff --git a/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc b/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..15f2663ce59bdc6fb99443bb5a6c7f197aba0ad4 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class LayerNormShiftPartitionOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert a fluid layernorm_shift_partition op to tensorrt "
+               "layernorm_shift_partition plugin";
+    framework::OpDesc op_desc(op, nullptr);
+
+    auto* X = engine_->GetITensor(op_desc.Input("X").front());
+    auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
+    auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
+    const int begin_norm_axis =
+        op_desc.HasAttr("begin_norm_axis")
+            ? PADDLE_GET_CONST(int, op_desc.GetAttr("begin_norm_axis"))
+            : 1;
+    const float eps = op_desc.HasAttr("epsilon")
+                          ? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
+                          : 1e-5f;
+    const int window_size =
+        PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
+    const int input_resolution =
+        PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
+    // int shift_size = window_size / 2;
+    // shift_size = (input_resolution <= window_size) ? 0 : shift_size;
+    int shift_size = 0;
+
+    PADDLE_ENFORCE_NOT_NULL(
+        Bias_v,
+        platform::errors::InvalidArgument(
+            "Input(Bias) of layer_norm should not be null."));
+    PADDLE_ENFORCE_NOT_NULL(
+        Scale_v,
+        platform::errors::InvalidArgument(
+            "Input(Scale) of layer_norm should not be null."));
+    PADDLE_ENFORCE_EQ(
+        begin_norm_axis,
+        2,
+        platform::errors::InvalidArgument(
+            "The begin_norm_axis of LayernormShiftPartition should be 2, "
+            "but got %d.",
+            begin_norm_axis));
+
+    auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
+    auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
+
+    auto bias_weight =
+        engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
+    auto scale_weight =
+        engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
+    bool with_fp16 =
+        engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
+    PADDLE_ENFORCE_EQ(bias_weight.get().count,
+                      scale_weight.get().count,
+                      platform::errors::InvalidArgument(
+                          "The element count of bias_weight and scale_weight should "
(%d vs %d)", + bias_weight.get().count, + scale_weight.get().count)); + nvinfer1::ILayer* layernorm_layer = nullptr; + if (engine_->with_dynamic_shape()) { + plugin::LayernormShiftPartitionPluginDynamic* plugin = + new plugin::LayernormShiftPartitionPluginDynamic( + static_cast(scale_weight.get().values), + static_cast(bias_weight.get().values), + bias_weight.get().count, + shift_size, + window_size, + input_resolution, + eps, + with_fp16); + layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "LayernormShiftPartition TRT Plugin should run in dynamic shape.")); + } + + auto output_name = op_desc.Output("Y").front(); + RreplenishLayerAndOutput( + layernorm_layer, "layernorm_shift_partition", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(layernorm_shift_partition, + LayerNormShiftPartitionOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 8243bb96205eb80f7952540ec5183285e7337316..6286010a03b3cb67056a4440014b0892b1a5d09c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -176,7 +176,8 @@ struct SimpleOpTypeSetTeller : public Teller { "sum", "shape", "squeeze2", - "unsqueeze2"}; + "unsqueeze2", + "layernorm_shift_partition"}; std::unordered_set teller_set{ "mul", "matmul", @@ -286,7 +287,8 @@ struct SimpleOpTypeSetTeller : public Teller { "shape", "squeeze2", "unsqueeze2", - "fused_token_prune"}; + "fused_token_prune", + "layernorm_shift_partition"}; }; bool OpTeller::Tell(const framework::ir::Node* node, @@ -2246,6 +2248,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, #endif } + if (op_type == "layernorm_shift_partition") { + if (!with_dynamic_shape) { + VLOG(3) << "the layernorm_shift_partition does not support " + "static shape yet"; + return false; + } + } + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index b41823d9186f6ec817fc7f71417933726254fafa..f602714f21150be9723f5ebac0071f17bc566b2f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -31,7 +31,8 @@ list( recover_padding_plugin.cu c_allreduce_op_plugin.cu preln_residual_bias_plugin.cu - fused_token_prune_op_plugin.cu) + fused_token_prune_op_plugin.cu + layernorm_shift_partition_op.cu) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND TRT_FILES spmm_plugin.cu) diff --git a/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu b/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ce38a545efe7002b6f81406095894dcb3a2f544c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.cu @@ -0,0 +1,665 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h" +#include "paddle/phi/kernels/layer_norm_kernel.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#define FINAL_MASK 0xffffffff + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + + return val; +} + +template +__global__ void layernorm_shift_partition(T *out, + const T *input, + const T *gamma, + const T *beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { + int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + + float local_out = + (tid < n) ? static_cast(__ldg(input + bid * n + tid)) : 0.0f; + + mean = blockReduceSum(local_out); + if (threadIdx.x == 0) { + s_mean = mean / n; + } + __syncthreads(); + + float diff = (tid < n) ? 
(local_out - s_mean) : 0.0f; + variance = blockReduceSum(diff * diff); + if (threadIdx.x == 0) { + s_variance = variance / n + eps; + } + __syncthreads(); + + if (tid < n) { + out[output_bid * n + tid] = + (T)(((local_out - s_mean) * rsqrtf(s_variance)) * + static_cast(__ldg(&gamma[tid])) + + static_cast(__ldg(&beta[tid]))); + } +} + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +template <> +__global__ void layernorm_shift_partition(half2 *out_ptr, + const half2 *input_ptr, + const half2 *gamma_ptr, + const half2 *beta_ptr, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + int tid = threadIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; + + float local_out = 0.0f; + int id = bid * n + tid; + if (tid < n) { + local_out_fp2 = __half22float2(__ldg(input_ptr + id)); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + } + + mean = blockReduceSum(local_out); + if (threadIdx.x == 0) { + s_mean = mean / (n * 2); + } + __syncthreads(); + + if (tid < n) { + variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); + variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (n * 2) + eps); + } + __syncthreads(); + + if (tid < n) { + float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); + float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); + local_out_fp2.x = + (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = + (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out_ptr[output_bid * n + tid] = __float22half2_rn(local_out_fp2); + } +} +#endif + +#define kITE 4 +template +__global__ void layernorm_shift_partition_v2(T *out, + const T *__restrict input, + const T *__restrict gamma, + const T *__restrict beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { + // constexpr int kITE = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? 
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + const int offset = bid * n; + const int output_offset = output_bid * n; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float local_out[kITE]; + + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + local_out[i] = static_cast(__ldg(input + offset + col_id)); + sum += local_out[i]; + } + } + + mean = blockReduceSum(sum); + if (tid == 0) { + s_mean = mean / n; + } + __syncthreads(); + + float var = 0.0f; +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + float diff = local_out[i] - s_mean; + local_out[i] = diff; + var += diff * diff; + } + } + + variance = blockReduceSum(var); + if (tid == 0) { + s_variance = rsqrtf(variance / n + eps); + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + out[output_offset + col_id] = + (T)(local_out[i] * s_variance * + static_cast(__ldg(&gamma[col_id])) + + static_cast(__ldg(&beta[col_id]))); + } + } +} + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +template <> +__global__ void layernorm_shift_partition_v2(half2 *out_ptr, + const half2 *__restrict input_ptr, + const half2 *__restrict gamma_ptr, + const half2 *__restrict beta_ptr, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { + // constexpr int ite = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? 
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + const int offset = bid * n; + const int output_offset = output_bid * n; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + half2 local_out_half2[kITE]; + const half2 zero = {static_cast(0.0f), static_cast(0.0f)}; + + // float sum = 0.0f; + half2 sum = __float2half2_rn(0.0f); +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + local_out_half2[i] = __ldg(input_ptr + offset + col_id); + sum += local_out_half2[i]; + } + } + + mean = blockReduceSum(static_cast(sum.x + sum.y)); + if (threadIdx.x == 0) { + s_mean = mean / (n * 2); + } + __syncthreads(); + + float var = 0.0f; + half2 s_mean_2 = __float2half2_rn(s_mean); +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + local_out_half2[i] = local_out_half2[i] - s_mean_2; + float v1 = static_cast(local_out_half2[i].x); + float v2 = static_cast(local_out_half2[i].y); + var += v1 * v1 + v2 * v2; + } + } + + variance = blockReduceSum(var); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (n * 2) + eps); + } + __syncthreads(); + + half2 s_var_2 = __float2half2_rn(s_variance); +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + out_ptr[output_offset + col_id] = + local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + + __ldg(&beta_ptr[col_id]); + } + } +} +#endif + +template +void invokeLayernormShiftPartition(T *out, + const T *input, + const T *gamma, + const T *beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps, + cudaStream_t stream) { + dim3 grid(W, H, batch); + int blockSize = (n + 31) / 32 * 32; + if (blockSize >= 768) { + blockSize = ((blockSize / 4) + 31) / 32 * 32; + layernorm_shift_partition_v2<<>>( + out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps); + } else { + layernorm_shift_partition<<>>( + out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps); + } +} + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) +template <> +void invokeLayernormShiftPartition(half *out, + const half *input, + const half *gamma, + const half *beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps, + cudaStream_t stream) { + dim3 grid(W, H, batch); + int blockSize = n / 2; + blockSize = (blockSize + 31) / 32 * 32; + + if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) { + blockSize = ((blockSize / 4) + 31) / 32 * 32; + layernorm_shift_partition_v2<<>>( + reinterpret_cast(out), + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + batch, + H, + W, + n / 2, + shift_size, + window_size, + eps); + } else { + layernorm_shift_partition<<>>( + reinterpret_cast(out), + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + batch, + H, + W, + n / 2, + shift_size, + window_size, + eps); + } +} +#endif + +template +static void convertAndCopy(const 
std::vector<float> &host, T *dev) {
+  T *host_ptr = new T[host.size()];
+  std::transform(host.begin(), host.end(), host_ptr, [](float x) {
+    return static_cast<T>(x);
+  });
+  cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice);
+  delete[] host_ptr;
+}
+
+void LayernormShiftPartitionPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *in,
+    int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *out,
+    int nbOutputs) TRT_NOEXCEPT {}
+
+LayernormShiftPartitionPluginDynamic::LayernormShiftPartitionPluginDynamic(
+    const float *gamma,
+    const float *beta,
+    const int param_num,
+    int shift_size,
+    int window_size,
+    int input_resolution,
+    float eps,
+    bool with_fp16,
+    std::shared_ptr<void> gamma_dev,
+    std::shared_ptr<void> beta_dev)
+    : with_fp16_(with_fp16),
+      window_size_(window_size),
+      shift_size_(shift_size),
+      input_resolution_(input_resolution),
+      eps_(eps),
+      param_num_(param_num),
+      gamma_dev_(gamma_dev),
+      beta_dev_(beta_dev) {
+  beta_.resize(param_num);
+  gamma_.resize(param_num);
+  std::copy(gamma, gamma + param_num, gamma_.data());
+  std::copy(beta, beta + param_num, beta_.data());
+  int type_size = with_fp16 ? sizeof(half) : sizeof(float);
+  if (gamma_dev_ == nullptr) {
+    void *p;
+    cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
+    gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
+    if (with_fp16)
+      convertAndCopy(gamma_, reinterpret_cast<half *>(p));
+    else
+      convertAndCopy(gamma_, reinterpret_cast<float *>(p));
+  }
+  if (beta_dev_ == nullptr) {
+    void *p;
+    cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
+    beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
+    if (with_fp16)
+      convertAndCopy(beta_, reinterpret_cast<half *>(p));
+    else
+      convertAndCopy(beta_, reinterpret_cast<float *>(p));
+  }
+}
+
+LayernormShiftPartitionPluginDynamic::LayernormShiftPartitionPluginDynamic(
+    void const *serialData, size_t serialLength) {
+  DeserializeValue(&serialData, &serialLength, &beta_);
+  DeserializeValue(&serialData, &serialLength, &gamma_);
+  DeserializeValue(&serialData, &serialLength, &param_num_);
+  DeserializeValue(&serialData, &serialLength, &with_fp16_);
+  DeserializeValue(&serialData, &serialLength, &shift_size_);
+  DeserializeValue(&serialData, &serialLength, &window_size_);
+  DeserializeValue(&serialData, &serialLength, &input_resolution_);
+  DeserializeValue(&serialData, &serialLength, &eps_);
+  int type_size = with_fp16_ ? sizeof(half) : sizeof(float);
+  {
+    void *p;
+    cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
+    gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
+    if (with_fp16_)
+      convertAndCopy(gamma_, reinterpret_cast<half *>(p));
+    else
+      convertAndCopy(gamma_, reinterpret_cast<float *>(p));
+  }
+  {
+    void *p;
+    cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
+    beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
+    if (with_fp16_)
+      convertAndCopy(beta_, reinterpret_cast<half *>(p));
+    else
+      convertAndCopy(beta_, reinterpret_cast<float *>(p));
+  }
+}
+
+bool LayernormShiftPartitionPluginDynamic::supportsFormatCombination(
+    int pos,
+    const nvinfer1::PluginTensorDesc *in_out,
+    int nb_inputs,
+    int nb_outputs) TRT_NOEXCEPT {
+  PADDLE_ENFORCE_NOT_NULL(
+      in_out,
+      platform::errors::InvalidArgument("The input of LayernormShiftPartition "
+                                        "plugin should not be nullptr."));
+  PADDLE_ENFORCE_LT(
+      pos,
+      nb_inputs + nb_outputs,
+      platform::errors::InvalidArgument("The pos(%d) should be less than the "
+                                        "num(%d) of the input and the output.",
+                                        pos,
+                                        nb_inputs + nb_outputs));
+  const nvinfer1::PluginTensorDesc &in = in_out[pos];
+  if (pos == 0) {
+    if (with_fp16_) {
+      return in.type == nvinfer1::DataType::kHALF &&
+             in.format == nvinfer1::TensorFormat::kLINEAR;
+    } else {
+      return in.type == nvinfer1::DataType::kFLOAT &&
+             in.format == nvinfer1::TensorFormat::kLINEAR;
+    }
+  }
+  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
+  // output
+  return in.type == prev.type && in.format == prev.format;
+}
+
+nvinfer1::DataType LayernormShiftPartitionPluginDynamic::getOutputDataType(
+    int index,
+    const nvinfer1::DataType *input_types,
+    int nb_inputs) const TRT_NOEXCEPT {
+  PADDLE_ENFORCE_EQ(
+      index,
+      0,
+      platform::errors::InvalidArgument(
+          "The LayernormShiftPartition only has one input, so the "
+          "index value should be 0, but got %d.",
+          index));
+  return input_types[0];
+}
+
+nvinfer1::DimsExprs LayernormShiftPartitionPluginDynamic::getOutputDimensions(
+    int output_index,
+    const nvinfer1::DimsExprs *inputs,
+    int nb_inputs,
+    nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
+  PADDLE_ENFORCE_EQ(
+      output_index,
+      0,
+      platform::errors::InvalidArgument(
+          "There is only one output of the LayernormShiftPartition, "
+          "so the index should be zero, but it's (%d)",
+          output_index));
+  PADDLE_ENFORCE_EQ(
+      nb_inputs,
+      1,
+      platform::errors::InvalidArgument(
+          "The input number of the LayernormShiftPartition should be 1, "
+          "but we found it has (%d) inputs",
+          nb_inputs));
+
+  nvinfer1::DimsExprs ret;
+  ret.nbDims = 3;
+  ret.d[0] = expr_builder.operation(
+      nvinfer1::DimensionOperation::kFLOOR_DIV,
+      *expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
+                              *inputs[0].d[0],
+                              *inputs[0].d[1]),
+      *expr_builder.constant(window_size_ * window_size_));
+  ret.d[1] = expr_builder.constant(window_size_ * window_size_);
+  ret.d[2] = inputs[0].d[2];
+  return ret;
+}
+
+int LayernormShiftPartitionPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *input_desc,
+    const nvinfer1::PluginTensorDesc *output_desc,
+    const void *const *inputs,
+    void *const *outputs,
+    void *workspace,
+    cudaStream_t stream) TRT_NOEXCEPT {
+  const auto &input_dims = input_desc[0].dims;
+  auto input_type = input_desc[0].type;
+  int batch = input_dims.d[0];
+  int emb_dim = input_dims.d[2];
+  PADDLE_ENFORCE_EQ(
+      input_resolution_ * input_resolution_,
+      input_dims.d[1],
+      platform::errors::InvalidArgument(
+          "The LayernormShiftPartition's input_resolution does not match the "
+          "input shape (%d).",
+          input_dims.d[1]));
+  if (input_type ==
nvinfer1::DataType::kFLOAT) { + VLOG(3) << "TRT Plugin DataType selected. LayernormShiftPartition-->fp32"; + invokeLayernormShiftPartition( + reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + reinterpret_cast(gamma_dev_.get()), + reinterpret_cast(beta_dev_.get()), + batch, + input_resolution_, + input_resolution_, + emb_dim, + shift_size_, + window_size_, + eps_, + stream); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(3) << "TRT Plugin DataType selected. LayernormShiftPartition-->half"; + invokeLayernormShiftPartition( + reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + reinterpret_cast(gamma_dev_.get()), + reinterpret_cast(beta_dev_.get()), + batch, + input_resolution_, + input_resolution_, + emb_dim, + shift_size_, + window_size_, + eps_, + stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The LayerNorm TRT Plugin's input type should be float or half.")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h b/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h new file mode 100644 index 0000000000000000000000000000000000000000..421a73af465777cd0c4ab0a99e45dcbd6f5e6777 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h @@ -0,0 +1,156 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class LayernormShiftPartitionPluginDynamic : public DynamicPluginTensorRT { + public: + LayernormShiftPartitionPluginDynamic( + const float* gamma, + const float* beta, + const int param_num, + int shift_size, + int window_size, + int input_resolution, + float eps, + bool with_fp16, + std::shared_ptr gamma_dev = nullptr, + std::shared_ptr beta_dev = nullptr); + + LayernormShiftPartitionPluginDynamic(void const* serialData, + size_t serialLength); + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new LayernormShiftPartitionPluginDynamic(gamma_.data(), + beta_.data(), + beta_.size(), + shift_size_, + window_size_, + input_resolution_, + eps_, + with_fp16_, + gamma_dev_, + beta_dev_); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "layernorm_shift_partition_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override { return 0; } + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(beta_) + SerializedSize(gamma_) + + SerializedSize(param_num_) + SerializedSize(with_fp16_) + + SerializedSize(shift_size_) + SerializedSize(window_size_) + + SerializedSize(input_resolution_) + SerializedSize(eps_); + } + + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, beta_); + SerializeValue(&buffer, gamma_); + SerializeValue(&buffer, param_num_); + SerializeValue(&buffer, with_fp16_); + SerializeValue(&buffer, shift_size_); + SerializeValue(&buffer, window_size_); + SerializeValue(&buffer, input_resolution_); + SerializeValue(&buffer, eps_); + } + + nvinfer1::DimsExprs getOutputDimensions(int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + bool with_fp16_; + std::vector gamma_; + std::vector beta_; + int window_size_; + int shift_size_; + int input_resolution_; + int param_num_; + float eps_; + std::shared_ptr gamma_dev_; + std::shared_ptr beta_dev_; +}; + +class LayernormShiftPartitionPluginDynamicCreator + : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "layernorm_shift_partition_dynamic"; + } + + 
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new LayernormShiftPartitionPluginDynamic(serial_data, serial_length); + } +}; + +REGISTER_TRT_PLUGIN_V2(LayernormShiftPartitionPluginDynamicCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..2d5dd9fe4bd5b6e7f2281a95979f4dc10e1ae0e4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py @@ -0,0 +1,208 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestLayernormShiftPartitionPass(PassAutoScanTest): + """ + | + layer_norm + | + reshape2 + | + reshape2 + | + transpose2 + | + reshape2 + | + reshape2 + | + """ + + def sample_predictor_configs(self, program_config): + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + config.set_trt_dynamic_shape_info({ + "input_data": [1, 9, 96], + }, { + "input_data": [4, 3136, 768], + }, { + "input_data": [1, 784, 384], + }) + yield config, ['layernorm_shift_partition'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + axis = [0, 1, 3, 2, 4, 5] + epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001)) + # begin_norm_axis has to be 2 + begin_norm_axis = 2 + batch_size = draw(st.integers(min_value=1, max_value=4)) + + window_size = draw(st.sampled_from([3, 5, 7])) + move_shape = draw(st.integers(min_value=1, max_value=8)) + dim = draw(st.sampled_from([96, 192, 384, 768])) + + def generate_input(attrs): + return np.random.random( + [attrs[1]["batch_size"], + *attrs[1]["input_dim"]]).astype(np.float32) + + def generate_weight(attrs): + return np.random.random(attrs[1]['input_dim'][-1]).astype( + np.float32) + + attrs = [{ + 'begin_norm_axis': begin_norm_axis, + 'epsilon': epsilon, + }, { + 'batch_size': batch_size, + 'input_dim': [(window_size * move_shape)**2, dim], + }, { + 'axis': axis, + 'input_resolution': window_size * move_shape, + 'move_shape': 
move_shape, + 'window_size': window_size, + }] + + layer_norm_op = OpConfig(type="layer_norm", + inputs={ + "X": ["input_data"], + "Bias": ["layer_norm_bias"], + "Scale": ["layer_norm_scale"] + }, + outputs={ + "Y": ["layer_norm_output1"], + "Mean": ["layer_norm_output2"], + "Variance": ["layer_norm_output3"] + }, + attrs={ + "begin_norm_axis": + attrs[0]["begin_norm_axis"], + "epsilon": attrs[0]["epsilon"], + }) + reshape_op2 = OpConfig(type="reshape2", + inputs={ + "X": ["layer_norm_output1"], + }, + outputs={ + "Out": ["reshape_output2"], + "XShape": ["reshape_output2_xshape"], + }, + attrs={ + 'shape': [ + -1, attrs[2]["input_resolution"], + attrs[2]["input_resolution"], + attrs[1]["input_dim"][-1] + ] + }) + reshape_op3 = OpConfig(type="reshape2", + inputs={ + "X": ["reshape_output2"], + }, + outputs={ + "Out": ["reshape_output3"], + "XShape": ["reshape_output3_xshape"], + }, + attrs={ + 'shape': [ + -1, attrs[2]["move_shape"], + attrs[2]["window_size"], + attrs[2]["move_shape"], + attrs[2]["window_size"], + attrs[1]["input_dim"][-1] + ] + }) + transpose_op4 = OpConfig(type='transpose2', + inputs={ + "X": ["reshape_output3"], + }, + outputs={"Out": ["transpose_output4"]}, + attrs={"axis": attrs[2]['axis']}) + reshape_op5 = OpConfig(type="reshape2", + inputs={ + "X": ["transpose_output4"], + }, + outputs={ + "Out": ["reshape_output5"], + "XShape": ["reshape_output5_xshape"], + }, + attrs={ + 'shape': [ + -1, attrs[2]["window_size"], + attrs[2]["window_size"], + attrs[1]["input_dim"][-1] + ] + }) + reshape_op6 = OpConfig( + type="reshape2", + inputs={ + "X": ["reshape_output5"], + }, + outputs={ + "Out": ["reshape_output6"], + "XShape": ["reshape_output6_xshape"], + }, + attrs={ + 'shape': + [-1, attrs[2]["window_size"]**2, attrs[1]["input_dim"][-1]] + }) + + program_config = ProgramConfig( + ops=[ + layer_norm_op, reshape_op2, reshape_op3, transpose_op4, + reshape_op5, reshape_op6 + ], + weights={ + "layer_norm_bias": + TensorConfig(data_gen=partial(generate_weight, attrs)), + "layer_norm_scale": + TensorConfig(data_gen=partial(generate_weight, attrs)) + }, + inputs={ + "input_data": + TensorConfig(data_gen=partial(generate_input, attrs)), + }, + outputs=["reshape_output6"]) + + return program_config + + def test(self): + self.run_and_statis(quant=False, + max_examples=20, + passes=["layernorm_shift_partition_fuse_pass"], + max_duration=250, + min_success_num=20) + + +if __name__ == "__main__": + unittest.main()
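A hypothetical end-to-end usage sketch (not part of this patch): a Paddle Inference script that enables the TensorRT engine with dynamic shapes so that layernorm_shift_partition_fuse_pass can fuse the matched subgraph and the plugin added above runs at inference time. It mirrors sample_predictor_configs in the unit test; the model file names and the "input_data" tensor name are placeholders.

import numpy as np
import paddle.inference as paddle_infer

# Placeholder model files; any exported model containing the
# layer_norm -> reshape2 -> reshape2 -> transpose2 -> reshape2 -> reshape2 chain applies.
config = paddle_infer.Config("swin.pdmodel", "swin.pdiparams")
config.enable_use_gpu(1000, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=0,
    precision_mode=paddle_infer.PrecisionType.Float32,
    use_static=False,
    use_calib_mode=False)
# The plugin only supports dynamic shape (see the op_teller change above),
# so shape ranges for the fed tensor must be registered.
config.set_trt_dynamic_shape_info(
    {"input_data": [1, 9, 96]},      # min shape
    {"input_data": [4, 3136, 768]},  # max shape
    {"input_data": [1, 784, 384]})   # optimal shape
predictor = paddle_infer.create_predictor(config)

# Feed a [batch, input_resolution^2, dim] tensor, e.g. window_size=7, move_shape=1, dim=96.
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 49, 96).astype("float32"))
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)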