diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 08d5e23b6f445f53cc364d365c4fd68a860ed01d..c6337a5a304e1970c31089e27afbe95d07fa8d4c 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -121,6 +121,7 @@ if(WITH_TENSORRT) pass_library(trt_map_matmul_to_mul_pass inference) pass_library(trt_multihead_matmul_fuse_pass inference) pass_library(trt_skip_layernorm_fuse_pass inference) + pass_library(merge_layernorm_fuse_pass inference) pass_library(preln_skip_layernorm_fuse_pass inference) pass_library(set_transformer_input_convert_pass inference) pass_library(remove_padding_recover_padding_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 0492651c6000dadc9c1ca396489fe284251b442b..72b19c1dd527b19df4e59ae5bfb2aa0be1830fb7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3638,6 +3638,92 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() { return reshape4_out; } +PDNode *patterns::MergeLayernormPattern::operator()(PDNode *in) { + in->AsInput(); + auto reshape2_00_op = + pattern->NewNode(reshape2_00_op_repr())->assert_is_op("reshape2"); + auto reshape2_00_out = pattern->NewNode(reshape2_00_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("strided_slice", "Input") + ->AsIntermediate(); + auto strided_slice_10_op = pattern->NewNode(strided_slice_10_op_repr()) + ->assert_is_op("strided_slice"); + auto strided_slice_10_out = pattern->NewNode(strided_slice_10_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_nth_input("concat", "X", 0) + ->AsIntermediate(); + auto strided_slice_11_op = pattern->NewNode(strided_slice_11_op_repr()) + ->assert_is_op("strided_slice"); + auto strided_slice_11_out = pattern->NewNode(strided_slice_11_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_nth_input("concat", "X", 1) + ->AsIntermediate(); + auto strided_slice_12_op = pattern->NewNode(strided_slice_12_op_repr()) + ->assert_is_op("strided_slice"); + auto strided_slice_12_out = pattern->NewNode(strided_slice_12_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_nth_input("concat", "X", 2) + ->AsIntermediate(); + auto strided_slice_13_op = pattern->NewNode(strided_slice_13_op_repr()) + ->assert_is_op("strided_slice"); + auto strided_slice_13_out = pattern->NewNode(strided_slice_13_out_repr()) + ->assert_is_op_output("strided_slice", "Out") + ->assert_is_op_nth_input("concat", "X", 3) + ->AsIntermediate(); + auto concat_20_op = pattern->NewNode(concat_20_op_repr()) + ->assert_is_op("concat") + ->assert_has_n_inputs(4); + auto concat_20_out = pattern->NewNode(concat_20_out_repr()) + ->assert_is_op_output("concat", "Out") + ->assert_is_op_input("reshape2", "X") + ->AsIntermediate(); + auto reshape2_30_op = + pattern->NewNode(reshape2_30_op_repr())->assert_is_op("reshape2"); + auto reshape2_30_out = pattern->NewNode(reshape2_30_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("layer_norm", "X") + ->AsIntermediate(); + auto layernorm_40_op = + pattern->NewNode(layernorm_40_op_repr()) + ->assert_is_op("layer_norm") + ->assert_more([&](Node *node) { + return node->Op()->HasAttr("begin_norm_axis") && + (PADDLE_GET_CONST( + int, node->Op()->GetAttr("begin_norm_axis")) == 2); + }); + auto layernorm_40_in_bias = pattern->NewNode(layernorm_40_in_bias_repr()) + ->assert_is_op_input("layer_norm", "Bias") + ->AsInput(); + auto layernorm_40_in_scale = pattern->NewNode(layernorm_40_in_scale_repr()) + ->assert_is_op_input("layer_norm", "Scale") + ->AsInput(); + auto layernorm_40_out = pattern->NewNode(layernorm_40_out_repr()) + ->assert_is_op_output("layer_norm", "Y") + ->AsOutput(); + + reshape2_00_op->LinksFrom({in}); + reshape2_00_out->LinksFrom({reshape2_00_op}); + strided_slice_10_op->LinksFrom({reshape2_00_out}); + strided_slice_10_out->LinksFrom({strided_slice_10_op}); + strided_slice_11_op->LinksFrom({reshape2_00_out}); + strided_slice_11_out->LinksFrom({strided_slice_11_op}); + strided_slice_12_op->LinksFrom({reshape2_00_out}); + strided_slice_12_out->LinksFrom({strided_slice_12_op}); + strided_slice_13_op->LinksFrom({reshape2_00_out}); + strided_slice_13_out->LinksFrom({strided_slice_13_op}); + concat_20_op->LinksFrom({strided_slice_10_out, + strided_slice_11_out, + strided_slice_12_out, + strided_slice_13_out}); + concat_20_out->LinksFrom({concat_20_op}); + reshape2_30_op->LinksFrom({concat_20_out}); + reshape2_30_out->LinksFrom({reshape2_30_op}); + layernorm_40_op->LinksFrom( + {reshape2_30_out, layernorm_40_in_bias, layernorm_40_in_scale}); + layernorm_40_out->LinksFrom({layernorm_40_op}); + return layernorm_40_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 5ce1dabc141c45a95c0fd4e32580f2decf321829..bd38b2123e9de4362a5617baecbce4d68fbd01a1 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1946,6 +1946,33 @@ struct LayernormShiftPartitionPattern : public PatternBase { PATTERN_DECL_NODE(reshape4_out); }; +// pattern for merge_layernorm +struct MergeLayernormPattern : public PatternBase { + MergeLayernormPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "merge_layernorm") {} + + PDNode* operator()(PDNode* reshape2_in); + + PATTERN_DECL_NODE(reshape2_00_op); + PATTERN_DECL_NODE(reshape2_00_out); + PATTERN_DECL_NODE(strided_slice_10_op); + PATTERN_DECL_NODE(strided_slice_10_out); + PATTERN_DECL_NODE(strided_slice_11_op); + PATTERN_DECL_NODE(strided_slice_11_out); + PATTERN_DECL_NODE(strided_slice_12_op); + PATTERN_DECL_NODE(strided_slice_12_out); + PATTERN_DECL_NODE(strided_slice_13_op); + PATTERN_DECL_NODE(strided_slice_13_out); + PATTERN_DECL_NODE(concat_20_op); + PATTERN_DECL_NODE(concat_20_out); + PATTERN_DECL_NODE(reshape2_30_op); + PATTERN_DECL_NODE(reshape2_30_out); + PATTERN_DECL_NODE(layernorm_40_op); + PATTERN_DECL_NODE(layernorm_40_in_bias); + PATTERN_DECL_NODE(layernorm_40_in_scale); + PATTERN_DECL_NODE(layernorm_40_out); +}; + // Add support int8 flag struct AddSupportInt8 : public PatternBase { AddSupportInt8(PDPattern* pattern, const std::string& name_scope) diff --git a/paddle/fluid/framework/ir/merge_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/merge_layernorm_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e6aaa37808aee6c21f75d45bd31de5594ee1294 --- /dev/null +++ b/paddle/fluid/framework/ir/merge_layernorm_fuse_pass.cc @@ -0,0 +1,189 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/ir/merge_layernorm_fuse_pass.h" + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(reshape2_00_op); \ + GET_IR_NODE(reshape2_00_out); \ + GET_IR_NODE(strided_slice_10_op); \ + GET_IR_NODE(strided_slice_10_out); \ + GET_IR_NODE(strided_slice_11_op); \ + GET_IR_NODE(strided_slice_11_out); \ + GET_IR_NODE(strided_slice_12_op); \ + GET_IR_NODE(strided_slice_12_out); \ + GET_IR_NODE(strided_slice_13_op); \ + GET_IR_NODE(strided_slice_13_out); \ + GET_IR_NODE(concat_20_op); \ + GET_IR_NODE(concat_20_out); \ + GET_IR_NODE(reshape2_30_op); \ + GET_IR_NODE(reshape2_30_out); \ + GET_IR_NODE(layernorm_40_op); \ + GET_IR_NODE(layernorm_40_in_bias); \ + GET_IR_NODE(layernorm_40_in_scale); \ + GET_IR_NODE(layernorm_40_out); +namespace paddle { +namespace framework { +namespace ir { +MergeLayernormFusePass::MergeLayernormFusePass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsOptional() + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + AddOpCompat(OpCompat("strided_slice")) + .AddInput("Input") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axes") + .IsType>() + .End() + .AddAttr("starts") + .IsType>() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("infer_flags") + .IsType>() + .End() + .AddAttr("ends") + .IsType>() + .End(); + AddOpCompat(OpCompat("concat")) + .AddInput("X") + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .End(); + AddOpCompat(OpCompat("layer_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Variance") + .IsTensor() + .IsOptional() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(0.001f) + .End() + .AddAttr("begin_norm_axis") + .IsNumEQ(2) + .End(); +} +void MergeLayernormFusePass::ApplyImpl(ir::Graph* graph) const { + GraphPatternDetector gpd; + const std::string pattern_name = "merge_layernorm"; + FusePassBase::Init(pattern_name, graph); + // auto* scope = param_scope(); + + PDNode* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + patterns::MergeLayernormPattern pattern(gpd.mutable_pattern(), pattern_name); + pattern(x); + int fusion_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } + GET_NODES; + OpDesc merge_layer_op_desc(reshape2_00_op->Op()->Block()); + merge_layer_op_desc.SetType("merge_layernorm"); + merge_layer_op_desc.SetInput("X", {subgraph.at(x)->Name()}); + merge_layer_op_desc.SetInput("Bias", {layernorm_40_in_bias->Name()}); + merge_layer_op_desc.SetInput("Scale", {layernorm_40_in_scale->Name()}); + merge_layer_op_desc.SetOutput("Y", {layernorm_40_out->Name()}); + merge_layer_op_desc.SetAttr( + "begin_norm_axis", layernorm_40_op->Op()->GetAttr("begin_norm_axis")); + merge_layer_op_desc.SetAttr("epsilon", + layernorm_40_op->Op()->GetAttr("epsilon")); + auto* merge_layer_op_node = graph->CreateOpNode(&merge_layer_op_desc); + IR_NODE_LINK_TO(subgraph.at(x), merge_layer_op_node); + IR_NODE_LINK_TO(layernorm_40_in_bias, merge_layer_op_node); + IR_NODE_LINK_TO(layernorm_40_in_scale, merge_layer_op_node); + IR_NODE_LINK_TO(merge_layer_op_node, layernorm_40_out); + GraphSafeRemoveNodes(graph, + {reshape2_00_op, + reshape2_00_out, + strided_slice_10_op, + strided_slice_10_out, + strided_slice_11_op, + strided_slice_11_out, + strided_slice_12_op, + strided_slice_12_out, + strided_slice_13_op, + strided_slice_13_out, + concat_20_op, + concat_20_out, + reshape2_30_op, + reshape2_30_out, + layernorm_40_op}); + ++fusion_count; + }; + gpd(graph, handler); + AddStatis(fusion_count); +} +} // namespace ir +} // namespace framework +} // namespace paddle +REGISTER_PASS(merge_layernorm_fuse_pass, + paddle::framework::ir::MergeLayernormFusePass); +REGISTER_PASS_CAPABILITY(merge_layernorm_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .EQ("concat", 0) + .EQ("layer_norm", 0)); diff --git a/paddle/fluid/framework/ir/merge_layernorm_fuse_pass.h b/paddle/fluid/framework/ir/merge_layernorm_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..9ad47560aee92025ea499171c59982a38074324e --- /dev/null +++ b/paddle/fluid/framework/ir/merge_layernorm_fuse_pass.h @@ -0,0 +1,55 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +// Fusing of path merge and layer_norm +// op: ss=stride_slice +// shape: [ss] = [?x28x28x96] +// input +// | [?x3136x96] +// reshape2 input +// | [?x56x56x96] | [?x3136x96] +// |------|------|------| merge_layernorm +// ss ss ss ss -> | [?x784x384] +// | [ss] | [ss] | [ss] | [ss] fused output +// |------|------|------| +// concat +// | [?x28x28x384] +// reshape2 +// | [?x784x384] +// layer_norm +// | [?x784x384] +// output +class MergeLayernormFusePass : public FusePassBase { + public: + MergeLayernormFusePass(); + virtual ~MergeLayernormFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 226d8acdc11db2d1aff11396f02bceaa4437fe9b..f4107975cba7000c12cb2b826d2cbacbec018e1c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2260,6 +2260,7 @@ USE_TRT_CONVERTER(shape) USE_TRT_CONVERTER(fill_constant) USE_TRT_CONVERTER(fused_token_prune) USE_TRT_CONVERTER(layernorm_shift_partition) +USE_TRT_CONVERTER(merge_layernorm) USE_TRT_CONVERTER(generic_plugin_creater) USE_TRT_CONVERTER(custom_plugin_creater) USE_TRT_CONVERTER(lookup_table) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 35537cd1fc692bf596b94f22c35deee2181864f8..931aad80ce241a0271f9897cb808d23e25ef90df 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -110,6 +110,7 @@ const std::vector kTRTSubgraphPasses({ "trt_skip_layernorm_fuse_pass", // "preln_skip_layernorm_fuse_pass", // "layernorm_shift_partition_fuse_pass", // + "merge_layernorm_fuse_pass", // "preln_residual_bias_fuse_pass", // // "set_transformer_input_convert_pass", // "conv_bn_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 75c12bf7ca71e1dde32ce625f8186e3e2cb43c47..441cc151462b7b9d119a9e4c9eb7ed2dd9e4d138 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -77,6 +77,7 @@ list( fill_constant_op.cc fused_token_prune_op.cc layernorm_shift_partition_op.cc + merge_layernorm_op.cc generic_and_custom_plugin_creater.cc fused_lookup_tables_op.cc expand_v2_op.cc) diff --git a/paddle/fluid/inference/tensorrt/convert/merge_layernorm_op.cc b/paddle/fluid/inference/tensorrt/convert/merge_layernorm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d47ad62d651e7f0b090a289bb6d8dabedf9ac494 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/merge_layernorm_op.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +class MergeLayernormOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid merge_layernorm op to tensorrt merge_layernorm " + "plugin"; + framework::OpDesc op_desc(op, nullptr); + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front()); + auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front()); + const int begin_norm_axis = + op_desc.HasAttr("begin_norm_axis") + ? PADDLE_GET_CONST(int, op_desc.GetAttr("begin_norm_axis")) + : 1; + const float eps = op_desc.HasAttr("epsilon") + ? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")) + : 1e-5f; + PADDLE_ENFORCE_NOT_NULL( + Bias_v, + platform::errors::InvalidArgument( + "Input(Bias) of layer_norm should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + Scale_v, + platform::errors::InvalidArgument( + "Input(Scale) of layer_norm should not be null.")); + PADDLE_ENFORCE_EQ( + begin_norm_axis, + 2, + platform::errors::InvalidArgument( + "The begin_norm_axis of LayernormShiftPartition should be %d", + begin_norm_axis)); + auto* Bias_t = Bias_v->GetMutable(); + auto* Scale_t = Scale_v->GetMutable(); + + auto bias_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t); + auto scale_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + nvinfer1::ILayer* merge_layernorm_layer = nullptr; + if (engine_->with_dynamic_shape()) { + plugin::MergeLayernormPluginDynamic* plugin = + new plugin::MergeLayernormPluginDynamic( + static_cast(bias_weight.get().values), + bias_weight.get().count, + static_cast(scale_weight.get().values), + scale_weight.get().count, + eps, + begin_norm_axis, + with_fp16); + merge_layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, MergeLayernorm TRT Plugin only support dynamic shape " + "mode.")); + } + auto output_name = op_desc.Output("Y").front(); + RreplenishLayerAndOutput( + merge_layernorm_layer, "merge_layernorm", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(merge_layernorm, MergeLayernormOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 60cd3887c120fc08d7fa5e19c416648c68061dd3..0e98afbac7f160d02beb44ddd2f32a578ad8b4e3 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2100,6 +2100,13 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } } + if (op_type == "merge_layernorm") { + if (!with_dynamic_shape) { + VLOG(3) << "The merge_layernorm op does not support " + "static shape yet"; + return false; + } + } if (op_type == "lookup_table") { if (!with_dynamic_shape) { @@ -2369,6 +2376,7 @@ struct SimpleOpTypeSetTeller : public Teller { "unsqueeze2", "fused_token_prune", "layernorm_shift_partition", + "merge_layernorm", "lookup_table", "lookup_table_v2", "expand_v2"}; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 1d5db4ee57f97163417c5c3a2237bd9978d4b55a..54d50ffabc09c2cf2d1d951d9b3adc956b7d9aa5 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -33,6 +33,7 @@ list( preln_residual_bias_plugin.cu fused_token_prune_op_plugin.cu layernorm_shift_partition_op.cu + merge_layernorm_op_plugin.cu generic_plugin.cu lookup_table.cu) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) diff --git a/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..c5afe8f34f49521eb3f087acc2f8413c8b9dd2ac --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu @@ -0,0 +1,339 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#define FINAL_MASK 0xffffffff + +template +__global__ void merge_layernorm_v2(T *out, + const T *__restrict input, + const T *__restrict gamma, + const T *__restrict beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n) { + // input is [batch, 2*H, 2*W, n/4] + // output is [batch, H, W, n] + // grid (W, H, batch) + // block (n) + const int kIte = 4; + const int tid = threadIdx.x; + const int W_idx = blockIdx.x; + const int H_idx = blockIdx.y; + const size_t batch_offset = blockIdx.z * H * W * n; + const int input_H_stride = W * n / 2; + const int output_H_stride = W * n; + const int n_4 = n >> 2; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float local_out[kIte]; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < kIte; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + int part_id = col_id / n_4; + int offset_in_W = part_id / 2; + int offset_in_H = part_id % 2; + size_t input_id = batch_offset + + (2 * H_idx + offset_in_H) * input_H_stride + + (2 * W_idx + offset_in_W) * n_4 + (col_id % n_4); + local_out[i] = static_cast(__ldg(input + input_id)); + sum += local_out[i]; + } + } + + mean = phi::funcs::blockReduceSum(sum, FINAL_MASK); + if (tid == 0) { + s_mean = mean / n; + } + __syncthreads(); + + float var = 0.0f; +#pragma unroll + for (int i = 0; i < kIte; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + local_out[i] = local_out[i] - s_mean; + var += local_out[i] * local_out[i]; + } + } + + variance = phi::funcs::blockReduceSum(var, FINAL_MASK); + if (tid == 0) { + s_variance = rsqrtf(variance / n + layernorm_eps); + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < kIte; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + size_t output_idx = + batch_offset + H_idx * output_H_stride + W_idx * n + col_id; + out[output_idx] = + static_cast(local_out[i] * s_variance * + static_cast(__ldg(&gamma[col_id])) + + static_cast(__ldg(&beta[col_id]))); + } + } +} + +template +void invokeMergeLayernorm(T *output, + const T *input, + const T *gamma, + const T *beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream) { + if ((W % 2 != 0) || (H % 2 != 0)) { + PADDLE_THROW(platform::errors::InvalidArgument( + "H(W) of merge layernorm should be a multiple of 2.")); + } + dim3 grid(W / 2, H / 2, batch); + int blockSize = (n + 31) / 32 * 32; + merge_layernorm_v2<<>>( + output, input, gamma, beta, layernorm_eps, batch, H / 2, W / 2, n * 4); +} + +template void invokeMergeLayernorm(float *output, + const float *input, + const float *gamma, + const float *beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream); + +template void invokeMergeLayernorm(half *output, + const half *input, + const half *gamma, + const half *beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream); + +template +static void convertAndCopy(const std::vector &host, T *dev) { + T *host_ptr = new T[host.size()]; + std::transform(host.begin(), host.end(), host_ptr, [](float x) { + return static_cast(x); + }); + cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice); + delete host_ptr; +} + +void MergeLayernormPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT {} + +MergeLayernormPluginDynamic::MergeLayernormPluginDynamic( + const float *bias_d, + const size_t bias_num, + const float *scale_d, + const size_t scale_num, + const float eps, + const int begin_norm_axis, + const bool with_fp16, + std::shared_ptr bias_device, + std::shared_ptr scale_device) + : eps_(eps), + begin_norm_axis_(begin_norm_axis), + with_fp16_(with_fp16), + bias_device_(bias_device), + scale_device_(scale_device) { + bias_.resize(bias_num); + scale_.resize(scale_num); + std::copy(bias_d, bias_d + bias_num, bias_.data()); + std::copy(scale_d, scale_d + scale_num, scale_.data()); + int type_size = with_fp16_ ? sizeof(half) : sizeof(float); + if (bias_device_ == nullptr) { + void *p; + cudaMalloc(&p, bias_num * type_size); + bias_device_.reset(p, [](void *ptr) { cudaFree(ptr); }); + + if (with_fp16) { + convertAndCopy(bias_, reinterpret_cast(p)); + } else { + convertAndCopy(bias_, reinterpret_cast(p)); + } + } + if (scale_device_ == nullptr) { + void *p; + cudaMalloc(&p, scale_num * type_size); + scale_device_.reset(p, [](void *ptr) { cudaFree(ptr); }); + if (with_fp16) { + convertAndCopy(scale_, reinterpret_cast(p)); + } else { + convertAndCopy(scale_, reinterpret_cast(p)); + } + } +} + +bool MergeLayernormPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument("The input of MergeLayernorm " + "plugin shoule not be nullptr.")); + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, + nb_inputs + nb_outputs)); + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return in.type == nvinfer1::DataType::kHALF && + in.format == nvinfer1::TensorFormat::kLINEAR; + } else { + return in.type == nvinfer1::DataType::kFLOAT && + in.format == nvinfer1::TensorFormat::kLINEAR; + } + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType MergeLayernormPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(index, + 0, + platform::errors::InvalidArgument( + "The MergeLayernorm only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +nvinfer1::DimsExprs MergeLayernormPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputs, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(output_index, + 0, + platform::errors::InvalidArgument( + "There is only one output of the MergeLayernorm, " + "so the index should be zero," + "but it's (%d)", + output_index)); + PADDLE_ENFORCE_EQ( + nb_inputs, + 1, + platform::errors::InvalidArgument( + "The Input of the MergeLayernorm should be 1, but we found " + "it has (%d) inputs", + nb_inputs)); + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, + *inputs[0].d[1], + *expr_builder.constant(4)); + ret.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *inputs[0].d[2], + *expr_builder.constant(4)); + return ret; +} + +int MergeLayernormPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, + void *const *outputs, + void *workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto &input_dims = input_desc[0].dims; + auto input_type = input_desc[0].type; + int batch = input_dims.d[0]; + int input_resolution = static_cast(std::sqrt(input_dims.d[1])); + int dim = static_cast(input_dims.d[2]); + PADDLE_ENFORCE_EQ( + input_resolution * input_resolution, + input_dims.d[1], + platform::errors::InvalidArgument( + "The MergeLayernorm TRT Plugin get invalid input_resolution %d", + input_resolution)); + + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(3) << "TRT Plugin DataType selected. MergeLayernorm-->fp32"; + invokeMergeLayernorm( + reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + reinterpret_cast(scale_device_.get()), + reinterpret_cast(bias_device_.get()), + eps_, + batch, + input_resolution, + input_resolution, + dim, + stream); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(3) << "TRT Plugin DataType selected. MergeLayernorm-->fp16"; + invokeMergeLayernorm( + reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + reinterpret_cast(scale_device_.get()), + reinterpret_cast(bias_device_.get()), + eps_, + batch, + input_resolution, + input_resolution, + dim, + stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The MergeLayernorm TRT Plugin's input type should be float or half.")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..cef07b369730ab47dd979f56d6e3dd3d16b7a932 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +class MergeLayernormPluginDynamic : public DynamicPluginTensorRT { + public: + MergeLayernormPluginDynamic(const float* bias_d, + const size_t bias_num, + const float* scale_d, + const size_t scale_num, + const float eps, + const int begin_norm_axis, + const bool with_fp16, + std::shared_ptr bias_device = nullptr, + std::shared_ptr scale_device = nullptr); + + MergeLayernormPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &bias_); + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &eps_); + DeserializeValue(&serialData, &serialLength, &begin_norm_axis_); + DeserializeValue(&serialData, &serialLength, &with_fp16_); + } + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new MergeLayernormPluginDynamic(bias_.data(), + bias_.size(), + scale_.data(), + scale_.size(), + eps_, + begin_norm_axis_, + with_fp16_, + bias_device_, + scale_device_); + } + const char* getPluginType() const TRT_NOEXCEPT override { + return "merge_layernorm_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override { return 0; } + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(bias_) + SerializedSize(scale_) + + SerializedSize(eps_) + SerializedSize(begin_norm_axis_) + + SerializedSize(with_fp16_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, bias_); + SerializeValue(&buffer, scale_); + SerializeValue(&buffer, eps_); + SerializeValue(&buffer, begin_norm_axis_); + SerializeValue(&buffer, with_fp16_); + } + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + std::vector bias_; + std::vector scale_; + float eps_; + int begin_norm_axis_; + bool with_fp16_; + std::shared_ptr bias_device_ = nullptr; + std::shared_ptr scale_device_ = nullptr; +}; +class MergeLayernormPluginDynamicCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "merge_layernorm_plugin_dynamic"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new MergeLayernormPluginDynamic(serial_data, serial_length); + } +}; + +REGISTER_TRT_PLUGIN_V2(MergeLayernormPluginDynamicCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 731ec7306fba5e3bb66b0692f7f957deb875c3c0..35ece8ffbf3910f964dea47de48d3b067ab561fe 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -122,6 +122,7 @@ if(WITH_GPU AND TENSORRT_FOUND) #set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200) set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120) set_tests_properties(test_trt_inspector PROPERTIES TIMEOUT 60) + set_tests_properties(test_merge_layernorm_fuse_pass PROPERTIES TIMEOUT 180) if(WITH_NV_JETSON) set_tests_properties( test_trt_pool_op diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_merge_layernorm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_merge_layernorm_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..cd25181f501e4b5f0b08770cc64537eea99ecd00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_merge_layernorm_fuse_pass.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +import unittest +import hypothesis.strategies as st + + +class TestMergeLayernormFusePass(PassAutoScanTest): + # input + # | [?x3136x96] + # reshape2 input + # | [?x56x56x96] | [?x3136x96] + # |--------------|--------------|--------------| merge_layernorm + # strided_slice strided_slice strided_slice strided_slice -> | [?x784x384] + # | [?x28x28x96] | [?x28x28x96] | [?x28x28x96] | fused output + # |--------------|--------------|--------------| + # concat + # | [?x28x28x384] + # reshape2 + # | [?x784x384] + # layer_norm + # | [?x784x384] + # output + + def sample_predictor_configs(self, program_config): + # trt dynamic_shape fp32 + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=1 << 20, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + config.set_trt_dynamic_shape_info({"input_data": [1, 196, 96]}, + {"input_data": [4, 3136, 384]}, + {"input_data": [1, 3136, 96]}) + yield config, ["merge_layernorm"], (1e-5, 1e-5) + # trt dynamic_shape fp16 + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=1 << 20, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Half, + use_static=False, + use_calib_mode=False) + config.set_trt_dynamic_shape_info({"input_data": [1, 196, 96]}, + {"input_data": [4, 3136, 384]}, + {"input_data": [1, 3136, 96]}) + yield config, ["merge_layernorm"], (1e-3, 1e-3) + + def sample_program_config(self, draw): + batch_size = draw(st.integers(min_value=1, max_value=4)) + input_H_W = draw(st.sampled_from([56, 28, 14])) + input_n = draw(st.sampled_from([96, 192, 384])) + layernorm_40_begin_norm_axis = 2 + layernorm_40_epsilon = draw( + st.floats(min_value=0.0000001, max_value=0.001)) + + def generate_input(attrs): + return np.random.random([ + attrs[3]['batch_size'], + attrs[3]['input_H_W'] * attrs[3]['input_H_W'], + attrs[3]['input_n'] + ]).astype(np.float32) + + def generate_weight(attrs): + return np.random.random([attrs[3]['input_n'] * 4 + ]).astype(np.float32) + + attrs = [{ + 'shape': [-1, input_H_W, input_H_W, input_n] + }, { + 'shape': [-1, int(input_H_W * input_H_W / 4), + int(input_n * 4)] + }, { + 'begin_norm_axis': layernorm_40_begin_norm_axis, + 'epsilon': layernorm_40_epsilon + }, { + 'batch_size': batch_size, + 'input_H_W': input_H_W, + 'input_n': input_n + }] + reshape2_00_op = OpConfig(type="reshape2", + inputs={'X': ['input_data']}, + outputs={ + 'Out': ['reshape2_00_out'], + 'XShape': ['reshape2_00_outxshape'] + }, + attrs={'shape': attrs[0]['shape']}) + strided_slice_10_op = OpConfig( + type="strided_slice", + inputs={'Input': ['reshape2_00_out']}, + outputs={'Out': ['strided_slice_10_out']}, + attrs={ + 'axes': [1, 2], + 'starts': [0, 0], + 'infer_flags': [1, 1], + 'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']], + 'strides': [2, 2] + }) + strided_slice_11_op = OpConfig( + type="strided_slice", + inputs={'Input': ['reshape2_00_out']}, + outputs={'Out': ['strided_slice_11_out']}, + attrs={ + 'axes': [1, 2], + 'starts': [1, 0], + 'infer_flags': [1, 1], + 'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']], + 'strides': [2, 2] + }) + strided_slice_12_op = OpConfig( + type="strided_slice", + inputs={'Input': ['reshape2_00_out']}, + outputs={'Out': ['strided_slice_12_out']}, + attrs={ + 'axes': [1, 2], + 'starts': [0, 1], + 'infer_flags': [1, 1], + 'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']], + 'strides': [2, 2] + }) + strided_slice_13_op = OpConfig( + type="strided_slice", + inputs={'Input': ['reshape2_00_out']}, + outputs={'Out': ['strided_slice_13_out']}, + attrs={ + 'axes': [1, 2], + 'starts': [1, 1], + 'infer_flags': [1, 1], + 'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']], + 'strides': [2, 2] + }) + concat_20_op = OpConfig(type="concat", + inputs={ + 'X': [ + 'strided_slice_10_out', + 'strided_slice_11_out', + 'strided_slice_12_out', + 'strided_slice_13_out' + ] + }, + outputs={'Out': ['concat_20_out']}, + attrs={'axis': -1}) + reshape2_30_op = OpConfig(type='reshape2', + inputs={'X': ['concat_20_out']}, + outputs={ + 'Out': ['reshape2_30_Out'], + 'XShape': ['reshape2_30_XShape'] + }, + attrs={'shape': attrs[1]['shape']}) + layernorm_40_op = OpConfig(type='layer_norm', + inputs={ + 'X': ['reshape2_30_Out'], + 'Bias': ['layer_norm_bias'], + 'Scale': ['layer_norm_scale'] + }, + outputs={ + "Y": ["layer_norm_out"], + "Mean": ["layer_norm_outMean"], + "Variance": ["layer_norm_outVariance"] + }, + attrs={ + 'begin_norm_axis': + attrs[2]['begin_norm_axis'], + 'epsilon': + attrs[2]['epsilon'] + }) + program_config = ProgramConfig( + ops=[ + reshape2_00_op, strided_slice_10_op, strided_slice_11_op, + strided_slice_12_op, strided_slice_13_op, concat_20_op, + reshape2_30_op, layernorm_40_op + ], + weights={ + 'layer_norm_bias': + TensorConfig(data_gen=partial(generate_weight, attrs)), + 'layer_norm_scale': + TensorConfig(data_gen=partial(generate_weight, attrs)) + }, + inputs={ + 'input_data': + TensorConfig(data_gen=partial(generate_input, attrs)) + }, + outputs=['layer_norm_out']) + return program_config + + def test(self): + self.run_and_statis(quant=False, + max_examples=50, + passes=["merge_layernorm_fuse_pass"], + max_duration=250, + min_success_num=50) + + +if __name__ == "__main__": + unittest.main()