From 1c6013ddceb00e6cd06f19e2bb5823f556238aa5 Mon Sep 17 00:00:00 2001 From: wenbin Date: Thu, 10 Nov 2022 11:40:12 +0800 Subject: [PATCH] skip_merge_layernorm (#47810) * skip_merge_layernorm * add UT * modify comments --- .../ir/preln_layernorm_x_fuse_pass.cc | 132 +++++-- .../ir/preln_layernorm_x_fuse_pass.h | 12 +- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../convert/skip_merge_layernorm_op.cc | 94 +++++ paddle/fluid/inference/tensorrt/op_teller.cc | 11 + .../inference/tensorrt/plugin/CMakeLists.txt | 1 + .../plugin/merge_layernorm_op_plugin.cu | 14 - .../plugin/skip_merge_layernorm_op_plugin.cu | 340 ++++++++++++++++++ .../plugin/skip_merge_layernorm_op_plugin.h | 141 ++++++++ .../unittests/ir/inference/CMakeLists.txt | 2 + .../test_skip_merge_layernorm_fuse_pass.py | 250 +++++++++++++ 12 files changed, 956 insertions(+), 43 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/skip_merge_layernorm_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.h create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_skip_merge_layernorm_fuse_pass.py diff --git a/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc index e6ab0f01ea..5f4c59333f 100644 --- a/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc @@ -36,7 +36,7 @@ struct PrelnLayerNormX : public PatternBase { PrelnLayerNormX(PDPattern *pattern, const std::string &name_scope) : PatternBase(pattern, name_scope, "preln_layernorm_x") {} - void operator()(PDNode *x, PDNode *y); + void operator()(PDNode *x, PDNode *y, const std::string &norm_type); // declare operator node's name PATTERN_DECL_NODE(elementwise_bias); PATTERN_DECL_NODE(elementwise0); @@ -51,34 
+51,33 @@ struct PrelnLayerNormX : public PatternBase { PATTERN_DECL_NODE(layer_norm_out); }; -void PrelnLayerNormX::operator()(PDNode *x, PDNode *y) { +void PrelnLayerNormX::operator()(PDNode *x, + PDNode *y, + const std::string &norm_type) { auto *elementwise1 = pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add"); auto *elementwise1_out_var = pattern->NewNode(elementwise1_out_repr()) ->assert_is_op_output("elementwise_add", "Out") - ->assert_is_op_input("layernorm_shift_partition", "X"); + ->assert_is_op_input(norm_type, "X"); elementwise1->LinksFrom({x, y}).LinksTo({elementwise1_out_var}); // Create nodes for layer_norm op. - auto *layer_norm = pattern->NewNode(layer_norm_repr()) - ->assert_is_op("layernorm_shift_partition"); - auto *layer_norm_bias_var = - pattern->NewNode(layer_norm_bias_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layernorm_shift_partition", "Bias"); - - auto *layer_norm_scale_var = - pattern->NewNode(layer_norm_scale_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("layernorm_shift_partition", "Scale"); - - auto *layer_norm_out_var = - pattern->NewNode(layer_norm_out_repr()) - ->AsOutput() - ->assert_is_op_output("layernorm_shift_partition", "Y"); + auto *layer_norm = + pattern->NewNode(layer_norm_repr())->assert_is_op(norm_type); + auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input(norm_type, "Bias"); + + auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input(norm_type, "Scale"); + + auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output(norm_type, "Y"); // Add links for layer_norm op. 
layer_norm @@ -89,7 +88,8 @@ void PrelnLayerNormX::operator()(PDNode *x, PDNode *y) { } // namespace patterns -int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const { +int PrelnLayerNormXFusePass::ApplyLayerNormShiftPattern( + ir::Graph *graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("preln_layernorm_x_fuse", graph); @@ -113,7 +113,7 @@ int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const { ->assert_is_op_input("elementwise_add", "Y"); patterns::PrelnLayerNormX fused_pattern(gpd.mutable_pattern(), "preln_layernorm_x_fuse"); - fused_pattern(x, y); + fused_pattern(x, y, "layernorm_shift_partition"); auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *graph) { @@ -137,10 +137,7 @@ int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const { LOG(WARNING) << "preln_layernorm_x_fuse pass in op compat failed."; return; } - static int cnt = 0; - if (cnt++ > 0) { - // return; - } + std::unordered_set del_node_set; // Create an PrelnLayerNormX op node OpDesc new_desc(*layer_norm->Op()); @@ -171,9 +168,88 @@ int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const { return found_subgraph_count; } +int PrelnLayerNormXFusePass::ApplyMergeLayerNormPattern( + ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("preln_layernorm_x_fuse", graph); + + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + PDNode *x = nullptr; + PDNode *y = nullptr; + + x = gpd.mutable_pattern() + ->NewNode("preln_layernorm_x_fuse/x") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "X"); + + y = gpd.mutable_pattern() + ->NewNode("preln_layernorm_x_fuse/y") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "Y"); + patterns::PrelnLayerNormX 
fused_pattern(gpd.mutable_pattern(), + "preln_layernorm_x_fuse"); + fused_pattern(x, y, "merge_layernorm"); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + VLOG(4) << "handle preln layernorm x fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise1_out, elementwise1_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "preln_layernorm_x_fuse pass in op compat failed."; + return; + } + std::unordered_set del_node_set; + // Create a skip_merge_layernorm op node + OpDesc new_desc(*layer_norm->Op()); + new_desc.SetType("skip_merge_layernorm"); + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("Y", {subgraph.at(y)->Name()}); + new_desc.SetOutput("Out", {layer_norm_out->Name()}); + new_desc.RemoveOutput("Y"); + new_desc.Flush(); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. 
+ + del_node_set.insert(elementwise1); + del_node_set.insert(layer_norm); + del_node_set.insert(elementwise1_out); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(subgraph.at(y), fused_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, layer_norm_out); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + void PrelnLayerNormXFusePass::ApplyImpl(ir::Graph *graph) const { FusePassBase::Init("preln_layernorm_x_fuse", graph); - int found_subgraph_count = ApplyPattern(graph); + int found_subgraph_count = ApplyLayerNormShiftPattern(graph); + found_subgraph_count += ApplyMergeLayerNormPattern(graph); AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h index dd720e353c..a79800c3be 100644 --- a/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h +++ b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h @@ -28,6 +28,15 @@ namespace ir { // other_op4 layernorm_shift_partition other_op4 other_op3 // | // other_op3 +// or +// | | | | +// other_op1 other_op2 other_op1 other_op2 +// | | fuse \ / +// |------elementwise_add -> skip_merge_layernorm +// | | | | +// other_op4 merge_layernorm other_op4 other_op3 +// | +// other_op3 class Graph; class PrelnLayerNormXFusePass : public FusePassBase { @@ -52,7 +61,8 @@ class PrelnLayerNormXFusePass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; - int ApplyPattern(ir::Graph* graph) const; + int ApplyLayerNormShiftPattern(ir::Graph* graph) const; + int ApplyMergeLayerNormPattern(ir::Graph* graph) const; }; } // namespace ir diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 345e8e6401..d9adc9426b 100644 --- 
a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2322,6 +2322,7 @@ USE_TRT_CONVERTER(celu) USE_TRT_CONVERTER(layernorm_shift_partition) USE_TRT_CONVERTER(preln_layernorm_shift_partition) USE_TRT_CONVERTER(merge_layernorm) +USE_TRT_CONVERTER(skip_merge_layernorm) USE_TRT_CONVERTER(generic_plugin_creater) USE_TRT_CONVERTER(custom_plugin_creater) USE_TRT_CONVERTER(tanh_shrink) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 72e2e1c7f0..849bccc3c2 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -82,6 +82,7 @@ list( logsigmoid_op.cc preln_layernorm_shift_partition_op.cc merge_layernorm_op.cc + skip_merge_layernorm_op.cc generic_and_custom_plugin_creater.cc fused_lookup_tables_op.cc expand_v2_op.cc) diff --git a/paddle/fluid/inference/tensorrt/convert/skip_merge_layernorm_op.cc b/paddle/fluid/inference/tensorrt/convert/skip_merge_layernorm_op.cc new file mode 100644 index 0000000000..eea2aa0ef3 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/skip_merge_layernorm_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +class SkipMergeLayernormOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid skip_merge_layernorm op to tensorrt " + "skip_merge_layernorm " + "plugin"; + framework::OpDesc op_desc(op, nullptr); + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + auto* Y = engine_->GetITensor(op_desc.Input("Y").front()); + auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front()); + auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front()); + const int begin_norm_axis = + op_desc.HasAttr("begin_norm_axis") + ? PADDLE_GET_CONST(int, op_desc.GetAttr("begin_norm_axis")) + : 1; + const float eps = op_desc.HasAttr("epsilon") + ? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")) + : 1e-5f; + PADDLE_ENFORCE_NOT_NULL( + Bias_v, + platform::errors::InvalidArgument( + "Input(Bias) of layer_norm should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + Scale_v, + platform::errors::InvalidArgument( + "Input(Scale) of layer_norm should not be null.")); + PADDLE_ENFORCE_EQ( + begin_norm_axis, + 2, + platform::errors::InvalidArgument( + "The begin_norm_axis of SkipLayerLayernorm should be %d", + begin_norm_axis)); + auto* Bias_t = Bias_v->GetMutable(); + auto* Scale_t = Scale_v->GetMutable(); + + auto bias_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t); + auto scale_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + nvinfer1::ILayer* skip_merge_layernorm_layer = nullptr; + if (engine_->with_dynamic_shape()) { + plugin::SkipMergeLayernormPluginDynamic* plugin = + new 
plugin::SkipMergeLayernormPluginDynamic( + static_cast(bias_weight.get().values), + bias_weight.get().count, + static_cast(scale_weight.get().values), + scale_weight.get().count, + eps, + begin_norm_axis, + with_fp16); + std::vector plugin_inputs{X, Y}; + skip_merge_layernorm_layer = + engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Currently, MergeLayernorm TRT Plugin only support dynamic shape " + "mode.")); + } + auto output_name = op_desc.Output("Out").front(); + RreplenishLayerAndOutput(skip_merge_layernorm_layer, + "skip_merge_layernorm", + {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(skip_merge_layernorm, SkipMergeLayernormOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index b4f8b7e929..75bc402deb 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2132,6 +2132,14 @@ struct SimpleOpTypeSetTeller : public Teller { } } + if (op_type == "skip_merge_layernorm") { + if (!with_dynamic_shape) { + VLOG(3) << "The merge_layernorm op does not support " + "static shape yet"; + return false; + } + } + if (op_type == "lookup_table") { if (!with_dynamic_shape) { VLOG(3) << "the lookup_table does not support " @@ -2288,6 +2296,8 @@ struct SimpleOpTypeSetTeller : public Teller { "logsigmoid", "preln_layernorm_shift_partition", "lookup_table", + "merge_layernorm", + "skip_merge_layernorm", // "lookup_table_v2", "expand_v2"}; @@ -2410,6 +2420,7 @@ struct SimpleOpTypeSetTeller : public Teller { "logsigmoid", "preln_layernorm_shift_partition", "merge_layernorm", + "skip_merge_layernorm", "lookup_table", // "lookup_table_v2", "expand_v2"}; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 
9d3d4c5532..a544d18a57 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -34,6 +34,7 @@ list( layernorm_shift_partition_op.cu prelnlayernorm_shift_partition_op.cu merge_layernorm_op_plugin.cu + skip_merge_layernorm_op_plugin.cu generic_plugin.cu lookup_table.cu) diff --git a/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu index c5afe8f34f..d94d4395b3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu @@ -255,20 +255,6 @@ nvinfer1::DimsExprs MergeLayernormPluginDynamic::getOutputDimensions( const nvinfer1::DimsExprs *inputs, int nb_inputs, nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ(output_index, - 0, - platform::errors::InvalidArgument( - "There is only one output of the MergeLayernorm, " - "so the index should be zero," - "but it's (%d)", - output_index)); - PADDLE_ENFORCE_EQ( - nb_inputs, - 1, - platform::errors::InvalidArgument( - "The Input of the MergeLayernorm should be 1, but we found " - "it has (%d) inputs", - nb_inputs)); nvinfer1::DimsExprs ret; ret.nbDims = 3; ret.d[0] = inputs[0].d[0]; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu new file mode 100644 index 0000000000..09c971a858 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu @@ -0,0 +1,340 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.h" +#include "paddle/phi/kernels/funcs/math_cuda_utils.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#define FINAL_MASK 0xffffffff + +template +__global__ void merge_layernorm_v2(T *out, + const T *__restrict input0, + const T *__restrict input1, + const T *__restrict gamma, + const T *__restrict beta, + const float layernorm_eps, + int batch, + int H, + int W, + int n) { + // input is [batch, 2*H, 2*W, n/4] + // output is [batch, H, W, n] + // grid (W, H, batch) + // block (n) + const int kIte = 4; + const int tid = threadIdx.x; + const int W_idx = blockIdx.x; + const int H_idx = blockIdx.y; + const size_t batch_offset = blockIdx.z * H * W * n; + const int input_H_stride = W * n / 2; + const int output_H_stride = W * n; + const int n_4 = n >> 2; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float local_out[kIte]; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < kIte; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + int part_id = col_id / n_4; + int offset_in_W = part_id / 2; + int offset_in_H = part_id % 2; + size_t input_id = batch_offset + + (2 * H_idx + offset_in_H) * input_H_stride + + (2 * W_idx + offset_in_W) * n_4 + (col_id % n_4); + local_out[i] = static_cast(__ldg(input0 + input_id)); + local_out[i] += static_cast(__ldg(input1 + input_id)); + sum += local_out[i]; + } + } + + mean = phi::funcs::blockReduceSum(sum, FINAL_MASK); + if (tid 
== 0) { + s_mean = mean / n; + } + __syncthreads(); + + float var = 0.0f; +#pragma unroll + for (int i = 0; i < kIte; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + local_out[i] = local_out[i] - s_mean; + var += local_out[i] * local_out[i]; + } + } + + variance = phi::funcs::blockReduceSum(var, FINAL_MASK); + if (tid == 0) { + s_variance = rsqrtf(variance / n + layernorm_eps); + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < kIte; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + size_t output_idx = + batch_offset + H_idx * output_H_stride + W_idx * n + col_id; + out[output_idx] = + static_cast(local_out[i] * s_variance * + static_cast(__ldg(&gamma[col_id])) + + static_cast(__ldg(&beta[col_id]))); + } + } +} + +template +void invokeMergeLayernorm(T *output, + const T *input0, + const T *input1, + const T *gamma, + const T *beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream) { + if ((W % 2 != 0) || (H % 2 != 0)) { + PADDLE_THROW(platform::errors::InvalidArgument( + "H(W) of merge layernorm should be a multiple of 2.")); + } + dim3 grid(W / 2, H / 2, batch); + int blockSize = (n + 31) / 32 * 32; + merge_layernorm_v2<<>>(output, + input0, + input1, + gamma, + beta, + layernorm_eps, + batch, + H / 2, + W / 2, + n * 4); +} + +template void invokeMergeLayernorm(float *output, + const float *input0, + const float *input1, + const float *gamma, + const float *beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream); + +template void invokeMergeLayernorm(half *output, + const half *input0, + const half *input1, + const half *gamma, + const half *beta, + float layernorm_eps, + int batch, + int H, + int W, + int n, + cudaStream_t stream); + +template +static void convertAndCopy(const std::vector &host, T *dev) { + T *host_ptr = new T[host.size()]; + std::transform(host.begin(), host.end(), host_ptr, [](float x) { + return static_cast(x); + }); + 
cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice); + delete host_ptr; +} + +void SkipMergeLayernormPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT {} + +SkipMergeLayernormPluginDynamic::SkipMergeLayernormPluginDynamic( + const float *bias_d, + const size_t bias_num, + const float *scale_d, + const size_t scale_num, + const float eps, + const int begin_norm_axis, + const bool with_fp16, + std::shared_ptr bias_device, + std::shared_ptr scale_device) + : eps_(eps), + begin_norm_axis_(begin_norm_axis), + with_fp16_(with_fp16), + bias_device_(bias_device), + scale_device_(scale_device) { + bias_.resize(bias_num); + scale_.resize(scale_num); + std::copy(bias_d, bias_d + bias_num, bias_.data()); + std::copy(scale_d, scale_d + scale_num, scale_.data()); + int type_size = with_fp16_ ? sizeof(half) : sizeof(float); + if (bias_device_ == nullptr) { + void *p; + cudaMalloc(&p, bias_num * type_size); + bias_device_.reset(p, [](void *ptr) { cudaFree(ptr); }); + + if (with_fp16) { + convertAndCopy(bias_, reinterpret_cast(p)); + } else { + convertAndCopy(bias_, reinterpret_cast(p)); + } + } + if (scale_device_ == nullptr) { + void *p; + cudaMalloc(&p, scale_num * type_size); + scale_device_.reset(p, [](void *ptr) { cudaFree(ptr); }); + if (with_fp16) { + convertAndCopy(scale_, reinterpret_cast(p)); + } else { + convertAndCopy(scale_, reinterpret_cast(p)); + } + } +} + +bool SkipMergeLayernormPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument("The input of MergeLayernorm " + "plugin shoule not be nullptr.")); + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the 
output.", + pos, + nb_inputs + nb_outputs)); + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return in.type == nvinfer1::DataType::kHALF && + in.format == nvinfer1::TensorFormat::kLINEAR; + } else { + return in.type == nvinfer1::DataType::kFLOAT && + in.format == nvinfer1::TensorFormat::kLINEAR; + } + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType SkipMergeLayernormPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(index, + 0, + platform::errors::InvalidArgument( + "The MergeLayernorm only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +nvinfer1::DimsExprs SkipMergeLayernormPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputs, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV, + *inputs[0].d[1], + *expr_builder.constant(4)); + ret.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *inputs[0].d[2], + *expr_builder.constant(4)); + return ret; +} + +int SkipMergeLayernormPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, + void *const *outputs, + void *workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto &input_dims = input_desc[0].dims; + auto input_type = input_desc[0].type; + int batch = input_dims.d[0]; + int input_resolution = static_cast(std::sqrt(input_dims.d[1])); + int dim = static_cast(input_dims.d[2]); + PADDLE_ENFORCE_EQ( + input_resolution * input_resolution, + input_dims.d[1], + platform::errors::InvalidArgument( + 
"The MergeLayernorm TRT Plugin get invalid input_resolution %d", + input_resolution)); + + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(3) << "TRT Plugin DataType selected. MergeLayernorm-->fp32"; + invokeMergeLayernorm( + reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[1]), + reinterpret_cast(scale_device_.get()), + reinterpret_cast(bias_device_.get()), + eps_, + batch, + input_resolution, + input_resolution, + dim, + stream); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(3) << "TRT Plugin DataType selected. MergeLayernorm-->fp16"; + invokeMergeLayernorm( + reinterpret_cast(outputs[0]), + reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[1]), + reinterpret_cast(scale_device_.get()), + reinterpret_cast(bias_device_.get()), + eps_, + batch, + input_resolution, + input_resolution, + dim, + stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The MergeLayernorm TRT Plugin's input type should be float or half.")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.h new file mode 100644 index 0000000000..343f2740ff --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +class SkipMergeLayernormPluginDynamic : public DynamicPluginTensorRT { + public: + SkipMergeLayernormPluginDynamic(const float* bias_d, + const size_t bias_num, + const float* scale_d, + const size_t scale_num, + const float eps, + const int begin_norm_axis, + const bool with_fp16, + std::shared_ptr bias_device = nullptr, + std::shared_ptr scale_device = nullptr); + + SkipMergeLayernormPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &bias_); + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &eps_); + DeserializeValue(&serialData, &serialLength, &begin_norm_axis_); + DeserializeValue(&serialData, &serialLength, &with_fp16_); + } + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new SkipMergeLayernormPluginDynamic(bias_.data(), + bias_.size(), + scale_.data(), + scale_.size(), + eps_, + begin_norm_axis_, + with_fp16_, + bias_device_, + scale_device_); + } + const char* getPluginType() const TRT_NOEXCEPT override { + return "skip_merge_layernorm_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override { return 0; } + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(bias_) + SerializedSize(scale_) + + SerializedSize(eps_) + SerializedSize(begin_norm_axis_) + + SerializedSize(with_fp16_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, bias_); + 
SerializeValue(&buffer, scale_); + SerializeValue(&buffer, eps_); + SerializeValue(&buffer, begin_norm_axis_); + SerializeValue(&buffer, with_fp16_); + } + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + std::vector bias_; + std::vector scale_; + float eps_; + int begin_norm_axis_; + bool with_fp16_; + std::shared_ptr bias_device_ = nullptr; + std::shared_ptr scale_device_ = nullptr; +}; +class SkipMergeLayernormPluginDynamicCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "skip_merge_layernorm_plugin_dynamic"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new SkipMergeLayernormPluginDynamic(serial_data, serial_length); + } +}; + 
+REGISTER_TRT_PLUGIN_V2(SkipMergeLayernormPluginDynamicCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 627d24a5f5..045dee09c7 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -124,6 +124,8 @@ if(WITH_GPU AND TENSORRT_FOUND) set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120) set_tests_properties(test_trt_inspector PROPERTIES TIMEOUT 60) set_tests_properties(test_merge_layernorm_fuse_pass PROPERTIES TIMEOUT 180) + set_tests_properties(test_skip_merge_layernorm_fuse_pass PROPERTIES TIMEOUT + 180) if(WITH_NV_JETSON) set_tests_properties( test_trt_pool_op diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_skip_merge_layernorm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_skip_merge_layernorm_fuse_pass.py new file mode 100644 index 0000000000..99fc3f8f53 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_skip_merge_layernorm_fuse_pass.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
import unittest
import hypothesis.strategies as st


class TestMergeLayernormFusePass(PassAutoScanTest):
    # Pattern under test: an elementwise_add whose output feeds a
    # merge_layernorm is expected to collapse into one skip_merge_layernorm.
    #
    #   other_op1   other_op2                 other_op1   other_op2
    #       |           |                          \         /
    #       +--- elementwise_add       fuse    skip_merge_layernorm
    #            |           |          ->         |         |
    #       other_op4  merge_layernorm         other_op4  other_op3
    #                        |
    #                    other_op3
    #
    # NOTE(review): the program built below spells the merge_layernorm part
    # out as reshape2 / strided_slice x4 / concat / reshape2 / layer_norm;
    # presumably an earlier pass folds those into merge_layernorm before
    # preln_layernorm_x_fuse_pass sees the graph -- confirm against the
    # pass pipeline.

    def sample_predictor_configs(self, program_config):
        # Yields two TensorRT dynamic-shape predictor configs, FP32 first and
        # FP16 second. Each is paired with the op list
        # ["skip_merge_layernorm"] and a tolerance pair.
        # NOTE(review): the tuple is presumably (config, expected op types
        # after fusion, (atol, rtol)) -- confirm in PassAutoScanTest.
        input_names = ("input0_data", "input1_data")
        # Shapes handed to set_trt_dynamic_shape_info, in call order
        # (min, max, opt). 196 == 14 * 14 and 3136 == 56 * 56, matching the
        # smallest / largest spatial sizes drawn in sample_program_config.
        shape_bounds = ([1, 196, 96], [4, 3136, 384], [1, 3136, 96])
        precision_cases = (
            # trt dynamic_shape fp32
            (paddle_infer.PrecisionType.Float32, (1e-5, 1e-5)),
            # trt dynamic_shape fp16
            (paddle_infer.PrecisionType.Half, (3e-3, 3e-3)),
        )

        for precision, tolerance in precision_cases:
            trt_config = self.create_trt_inference_config()
            trt_config.enable_tensorrt_engine(
                max_batch_size=1,
                workspace_size=1 << 20,
                min_subgraph_size=0,
                precision_mode=precision,
                use_static=False,
                use_calib_mode=False,
            )
            # Both inputs share the same bounds; every dict gets its own
            # fresh shape lists.
            min_shape, max_shape, opt_shape = (
                {name: list(dims) for name in input_names}
                for dims in shape_bounds
            )
            trt_config.set_trt_dynamic_shape_info(
                min_shape, max_shape, opt_shape
            )
            yield trt_config, ["skip_merge_layernorm"], tolerance

    def sample_program_config(self, draw):
        # Builds one randomly-sized program:
        #   elementwise_add -> reshape2 to [-1, H, W, C]
        #   -> four stride-2 strided_slice ops (offsets 00, 10, 01, 11)
        #   -> concat on the last axis -> reshape2 to [-1, H*W/4, 4*C]
        #   -> layer_norm (begin_norm_axis = 2).
        #
        # The draw order below (batch, side, channels, epsilon) is kept
        # fixed so hypothesis replays stay stable.
        batch_size = draw(st.integers(min_value=1, max_value=4))
        side = draw(st.sampled_from([56, 28, 14]))
        channels = draw(st.sampled_from([96, 192, 384]))
        norm_axis = 2
        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))

        def random_float32(shape):
            # Uniform [0, 1) samples of the requested shape, as float32.
            return np.random.random(shape).astype(np.float32)

        # Both graph inputs are [batch, H*W, C]; layer_norm scale/bias are
        # 1-D with 4*C elements (the width after the 4-way concat).
        input_gen = partial(
            random_float32, [batch_size, side * side, channels]
        )
        weight_gen = partial(random_float32, [channels * 4])

        add_op = OpConfig(
            type="elementwise_add",
            inputs={'X': ['input0_data'], 'Y': ['input1_data']},
            outputs={'Out': ['elementadd_op_out']},
            attrs={'axis': -1},
        )
        unflatten_op = OpConfig(
            type="reshape2",
            inputs={'X': ['elementadd_op_out']},
            outputs={
                'Out': ['reshape2_00_out'],
                'XShape': ['reshape2_00_outxshape'],
            },
            attrs={'shape': [-1, side, side, channels]},
        )

        # One strided_slice per 2x2 offset; output names are
        # strided_slice_10_out .. strided_slice_13_out, in this order.
        slice_offsets = [(0, 0), (1, 0), (0, 1), (1, 1)]
        slice_out_names = [
            f'strided_slice_1{idx}_out' for idx in range(len(slice_offsets))
        ]
        slice_ops = [
            OpConfig(
                type="strided_slice",
                inputs={'Input': ['reshape2_00_out']},
                outputs={'Out': [out_name]},
                attrs={
                    'axes': [1, 2],
                    'starts': list(offset),
                    'infer_flags': [1, 1],
                    'ends': [side, side],
                    'strides': [2, 2],
                },
            )
            for out_name, offset in zip(slice_out_names, slice_offsets)
        ]

        concat_op = OpConfig(
            type="concat",
            inputs={'X': list(slice_out_names)},
            outputs={'Out': ['concat_20_out']},
            attrs={'axis': -1},
        )
        flatten_op = OpConfig(
            type='reshape2',
            inputs={'X': ['concat_20_out']},
            outputs={
                'Out': ['reshape2_30_Out'],
                'XShape': ['reshape2_30_XShape'],
            },
            attrs={'shape': [-1, (side * side) // 4, channels * 4]},
        )
        layer_norm_op = OpConfig(
            type='layer_norm',
            inputs={
                'X': ['reshape2_30_Out'],
                'Bias': ['layer_norm_bias'],
                'Scale': ['layer_norm_scale'],
            },
            outputs={
                "Y": ["layer_norm_out"],
                "Mean": ["layer_norm_outMean"],
                "Variance": ["layer_norm_outVariance"],
            },
            attrs={
                'begin_norm_axis': norm_axis,
                'epsilon': epsilon,
            },
        )

        # Tensor configs are created in the order bias, scale, input0,
        # input1 so the NumPy RNG is consumed in a fixed sequence.
        return ProgramConfig(
            ops=[
                add_op,
                unflatten_op,
                *slice_ops,
                concat_op,
                flatten_op,
                layer_norm_op,
            ],
            weights={
                'layer_norm_bias': TensorConfig(data_gen=weight_gen),
                'layer_norm_scale': TensorConfig(data_gen=weight_gen),
            },
            inputs={
                'input0_data': TensorConfig(data_gen=input_gen),
                'input1_data': TensorConfig(data_gen=input_gen),
            },
            outputs=['layer_norm_out'],
        )

    def test(self):
        # Runs the auto-scan harness with only preln_layernorm_x_fuse_pass
        # requested, asking for 50 examples and 50 successes.
        self.run_and_statis(
            quant=False,
            max_examples=50,
            passes=["preln_layernorm_x_fuse_pass"],
            max_duration=250,
            min_success_num=50,
        )


if __name__ == "__main__":
    unittest.main()