diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 96ed8e663d7b0e2924586bebfaae35e91a84fc28..bb8af808d27e3cb26bc5a7245793e6efd456019a 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -129,6 +129,7 @@ if(WITH_TENSORRT) pass_library(remove_padding_recover_padding_pass inference) pass_library(delete_remove_padding_recover_padding_pass inference) pass_library(layernorm_shift_partition_fuse_pass inference) + pass_library(preln_layernorm_x_fuse_pass inference) endif() if(WITH_TENSORRT AND NOT WIN32) diff --git a/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..e6ab0f01ea4234dc4f4dea84fc1eaef49adc8d8a --- /dev/null +++ b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc @@ -0,0 +1,189 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h" + +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct PrelnLayerNormX : public PatternBase { + PrelnLayerNormX(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "preln_layernorm_x") {} + + void operator()(PDNode *x, PDNode *y); + // declare operator node's name + PATTERN_DECL_NODE(elementwise_bias); + PATTERN_DECL_NODE(elementwise0); + PATTERN_DECL_NODE(elementwise1); + PATTERN_DECL_NODE(layer_norm); + // declare variable node's name + PATTERN_DECL_NODE(elementwise0_out); + PATTERN_DECL_NODE(elementwise1_out); + + PATTERN_DECL_NODE(layer_norm_bias); + PATTERN_DECL_NODE(layer_norm_scale); + PATTERN_DECL_NODE(layer_norm_out); +}; + +void PrelnLayerNormX::operator()(PDNode *x, PDNode *y) { + auto *elementwise1 = + pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add"); + auto *elementwise1_out_var = + pattern->NewNode(elementwise1_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("layernorm_shift_partition", "X"); + + elementwise1->LinksFrom({x, y}).LinksTo({elementwise1_out_var}); + // Create nodes for layer_norm op. 
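+  // Note: the "layer_norm" node below is a layernorm_shift_partition op, so this
+  // pattern can only match after layernorm_shift_partition_fuse_pass (listed
+  // earlier in kTRTSubgraphPasses) has already produced that op.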
+ auto *layer_norm = pattern->NewNode(layer_norm_repr()) + ->assert_is_op("layernorm_shift_partition"); + auto *layer_norm_bias_var = + pattern->NewNode(layer_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layernorm_shift_partition", "Bias"); + + auto *layer_norm_scale_var = + pattern->NewNode(layer_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("layernorm_shift_partition", "Scale"); + + auto *layer_norm_out_var = + pattern->NewNode(layer_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("layernorm_shift_partition", "Y"); + + // Add links for layer_norm op. + layer_norm + ->LinksFrom( + {elementwise1_out_var, layer_norm_bias_var, layer_norm_scale_var}) + .LinksTo({layer_norm_out_var}); +} + +} // namespace patterns + +int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("preln_layernorm_x_fuse", graph); + + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + PDNode *x = nullptr; + PDNode *y = nullptr; + + x = gpd.mutable_pattern() + ->NewNode("preln_layernorm_x_fuse/x") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "X"); + + y = gpd.mutable_pattern() + ->NewNode("preln_layernorm_x_fuse/y") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "Y"); + patterns::PrelnLayerNormX fused_pattern(gpd.mutable_pattern(), + "preln_layernorm_x_fuse"); + fused_pattern(x, y); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + VLOG(4) << "handle preln layernorm x fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + elementwise1_out, elementwise1_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + layer_norm_scale, layer_norm_scale, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern); + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "preln_layernorm_x_fuse pass in op compat failed."; + return; + } + static int cnt = 0; + if (cnt++ > 0) { + // return; + } + std::unordered_set del_node_set; + // Create an PrelnLayerNormX op node + OpDesc new_desc(*layer_norm->Op()); + new_desc.SetType("preln_layernorm_shift_partition"); + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("Y", {subgraph.at(y)->Name()}); + new_desc.SetOutput("Out_0", {elementwise1_out->Name()}); + new_desc.SetOutput("Out_1", {layer_norm_out->Name()}); + new_desc.RemoveOutput("Y"); + new_desc.Flush(); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. 
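+    // The fused node consumes x, y, Scale and Bias and exposes two outputs:
+    // Out_0 (the residual sum that used to be the elementwise_add output) and
+    // Out_1 (the normalized, window-partitioned tensor that used to be the
+    // layernorm_shift_partition output). Both original ops are removed below.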
+ + del_node_set.insert(elementwise1); + del_node_set.insert(layer_norm); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(subgraph.at(y), fused_node); + IR_NODE_LINK_TO(layer_norm_scale, fused_node); + IR_NODE_LINK_TO(layer_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, layer_norm_out); + IR_NODE_LINK_TO(fused_node, elementwise1_out); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +void PrelnLayerNormXFusePass::ApplyImpl(ir::Graph *graph) const { + FusePassBase::Init("preln_layernorm_x_fuse", graph); + int found_subgraph_count = ApplyPattern(graph); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(preln_layernorm_x_fuse_pass, + paddle::framework::ir::PrelnLayerNormXFusePass); +REGISTER_PASS_CAPABILITY(preln_layernorm_x_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().LE( + "elementwise_add", 1)); diff --git a/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..dd720e353c938dae101871ca7c21cebb8c52c61f --- /dev/null +++ b/paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +// +// | | | | +// other_op1 other_op2 other_op1 other_op2 +// | | fuse \ / +// |------elementwise_add -> preln_layernorm_shift_partition +// | | | | +// other_op4 layernorm_shift_partition other_op4 other_op3 +// | +// other_op3 +class Graph; + +class PrelnLayerNormXFusePass : public FusePassBase { + public: + PrelnLayerNormXFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1, 2}) + .End(); + } + + virtual ~PrelnLayerNormXFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + int ApplyPattern(ir::Graph* graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f1430dc7d30182886f10ef5d40965a5a3a741f20..f49c9faeb3d19b62db676cd7b248de1656bf4679 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2271,6 +2271,7 @@ USE_TRT_CONVERTER(shape) USE_TRT_CONVERTER(fill_constant) USE_TRT_CONVERTER(fused_token_prune) USE_TRT_CONVERTER(layernorm_shift_partition) +USE_TRT_CONVERTER(preln_layernorm_shift_partition) USE_TRT_CONVERTER(merge_layernorm) USE_TRT_CONVERTER(generic_plugin_creater) USE_TRT_CONVERTER(custom_plugin_creater) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 23c173b3f58b3a2c253d739a69618b324d7012a3..59ebbb5764a5668ec6e17326ac52af414e83c9ae 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -113,6 +113,7 @@ const std::vector kTRTSubgraphPasses({ "layernorm_shift_partition_fuse_pass", // "merge_layernorm_fuse_pass", // "preln_residual_bias_fuse_pass", // + "preln_layernorm_x_fuse_pass", // // "set_transformer_input_convert_pass", // "conv_bn_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 441cc151462b7b9d119a9e4c9eb7ed2dd9e4d138..a6c0a42de37abeea4de2d6b1228d90e49d18bb61 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -77,6 +77,7 @@ list( fill_constant_op.cc fused_token_prune_op.cc layernorm_shift_partition_op.cc + preln_layernorm_shift_partition_op.cc merge_layernorm_op.cc generic_and_custom_plugin_creater.cc fused_lookup_tables_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/preln_layernorm_shift_partition_op.cc b/paddle/fluid/inference/tensorrt/convert/preln_layernorm_shift_partition_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b7719e3ac819bc51b507c0e4da7a858c8eee00a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/preln_layernorm_shift_partition_op.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PrelnLayerNormShiftPartitionOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid preln_layernorm_shift_partition op to tensorrt " + "preln_layernorm_shift_partition plugin"; + framework::OpDesc op_desc(op, nullptr); + + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + auto* Y = engine_->GetITensor(op_desc.Input("Y").front()); + + std::vector inputs{X, Y}; + + auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front()); + auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front()); + + const float eps = op_desc.HasAttr("epsilon") + ? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")) + : 1e-5f; + const int window_size = + PADDLE_GET_CONST(int, op_desc.GetAttr("window_size")); + const int input_resolution = + PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution")); + + const int shift_size = + op_desc.HasAttr("shift_size") + ? PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size")) + : 0; + + auto* Bias_t = Bias_v->GetMutable(); + auto* Scale_t = Scale_v->GetMutable(); + + auto bias_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t); + auto scale_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + nvinfer1::ILayer* layernorm_layer = nullptr; + if (engine_->with_dynamic_shape()) { + plugin::PrelnLnormShiftPartitionPluginDynamic* plugin = + new plugin::PrelnLnormShiftPartitionPluginDynamic( + static_cast(scale_weight.get().values), + static_cast(bias_weight.get().values), + bias_weight.get().count, + shift_size, + window_size, + input_resolution, + eps, + with_fp16); + layernorm_layer = + engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin); + } + + std::vector output_names; + output_names.emplace_back(op_desc.Output("Out_0").front()); + output_names.emplace_back(op_desc.Output("Out_1").front()); + RreplenishLayerAndOutput(layernorm_layer, + "preln_layernorm_shift_partition", + output_names, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(preln_layernorm_shift_partition, + PrelnLayerNormShiftPartitionOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 0e98afbac7f160d02beb44ddd2f32a578ad8b4e3..3e6f5779c6fa87e79f93998817df51748f1e32b5 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2100,6 +2100,15 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } } + + if (op_type == "preln_layernorm_shift_partition") { + if (!with_dynamic_shape) { + VLOG(3) << "the layernorm_shift_partition does not support " + "static shape yet"; + return false; + } + } + if 
(op_type == "merge_layernorm") { if (!with_dynamic_shape) { VLOG(3) << "The merge_layernorm op does not support " @@ -2259,9 +2268,11 @@ struct SimpleOpTypeSetTeller : public Teller { "squeeze2", "unsqueeze2", "layernorm_shift_partition", + "preln_layernorm_shift_partition", "lookup_table", "lookup_table_v2", "expand_v2"}; + std::unordered_set teller_set{ "mul", "matmul", @@ -2376,6 +2387,7 @@ struct SimpleOpTypeSetTeller : public Teller { "unsqueeze2", "fused_token_prune", "layernorm_shift_partition", + "preln_layernorm_shift_partition", "merge_layernorm", "lookup_table", "lookup_table_v2", diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 54d50ffabc09c2cf2d1d951d9b3adc956b7d9aa5..a1544065cfce0eab579471e0ffb32b19e84ee42f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -33,9 +33,11 @@ list( preln_residual_bias_plugin.cu fused_token_prune_op_plugin.cu layernorm_shift_partition_op.cu + prelnlayernorm_shift_partition_op.cu merge_layernorm_op_plugin.cu generic_plugin.cu lookup_table.cu) + if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu many_emb_Layernorm_varseqlen_kernelMTron.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.cu b/paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e6a90c06aaae75c40452892572943c7537e5b67b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.cu @@ -0,0 +1,714 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h" +#include "paddle/phi/kernels/layer_norm_kernel.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#define FINAL_MASK 0xffffffff + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? 
shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + + return val; +} + +template +__global__ void preln_layernorm_shift_partition(T *out0, + T *out1, + const T *input0, + const T *input1, + const T *gamma, + const T *beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { + int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + const int index = bid * n + tid; + float local_out = 0; + if (tid < n) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + local_out = static_cast(__ldg(input0 + index)); +#else + local_out = static_cast(input0[index]); +#endif + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + local_out += static_cast(__ldg(input1 + index)); +#else + local_out += static_cast(input1[index]); +#endif + out0[index] = local_out; + } + + mean = blockReduceSum(local_out); + if (threadIdx.x == 0) { + s_mean = mean / n; + } + __syncthreads(); + + float diff = (tid < n) ? (local_out - s_mean) : 0.0f; + variance = blockReduceSum(diff * diff); + if (threadIdx.x == 0) { + s_variance = variance / n + eps; + } + __syncthreads(); + + if (tid < n) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + out1[output_bid * n + tid] = + (T)(((local_out - s_mean) * rsqrtf(s_variance)) * + static_cast(__ldg(&gamma[tid])) + + static_cast(__ldg(&beta[tid]))); +#else + out1[output_bid * n + tid] = + (T)(((local_out - s_mean) * rsqrtf(s_variance)) * + static_cast(gamma[tid]) + + static_cast(beta[tid])); +#endif + } +} + +template <> +__global__ void preln_layernorm_shift_partition(half2 *out0_ptr, + half2 *out1_ptr, + const half2 *input0_ptr, + const half2 *input1_ptr, + const half2 *gamma_ptr, + const half2 *beta_ptr, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? 
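+          // (idx - shift_size + dim) % dim is the cyclic shift of the
+          // shifted-window scheme; with shift_size == 0 the index is unchanged.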
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + int tid = threadIdx.x; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float2 local_out_fp2; + + float local_out = 0.0f; + int id = bid * n + tid; + if (tid < n) { + half2 tmp = __hadd2(__ldg(input0_ptr + id), __ldg(input1_ptr + id)); + local_out_fp2 = __half22float2(tmp); + local_out += local_out_fp2.x; + local_out += local_out_fp2.y; + + out0_ptr[id] = tmp; + } + + mean = blockReduceSum(local_out); + if (threadIdx.x == 0) { + s_mean = mean / (n * 2); + } + __syncthreads(); + + if (tid < n) { + variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean); + variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean); + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (n * 2) + eps); + } + __syncthreads(); + + if (tid < n) { + float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid])); + float2 beta_val = __half22float2(__ldg(&beta_ptr[tid])); + local_out_fp2.x = + (local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x; + local_out_fp2.y = + (local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y; + out1_ptr[output_bid * n + tid] = __float22half2_rn(local_out_fp2); + } +#endif +} + +#define kITE 4 +template +__global__ void preln_layernorm_shift_partition_v2(T *out0, + T *out1, + const T *__restrict input0, + const T *__restrict input1, + const T *__restrict gamma, + const T *__restrict beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { + // constexpr int kITE = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? 
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + const int offset = bid * n; + const int output_offset = output_bid * n; + + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + float local_out[kITE]; + + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + int index = offset + col_id; + if (col_id < n) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + local_out[i] = static_cast(__ldg(input0 + index)); +#else + local_out[i] = static_cast(input0[index]); +#endif + +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + local_out[i] += static_cast(__ldg(input1 + index)); +#else + local_out[i] += static_cast(input1[index]); +#endif + out0[index] = local_out[i]; + sum += local_out[i]; + } + } + + mean = blockReduceSum(sum); + if (tid == 0) { + s_mean = mean / n; + } + __syncthreads(); + + float var = 0.0f; +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + float diff = local_out[i] - s_mean; + local_out[i] = diff; + var += diff * diff; + } + } + + variance = blockReduceSum(var); + if (tid == 0) { + s_variance = rsqrtf(variance / n + eps); + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + out1[output_offset + col_id] = + (T)(local_out[i] * s_variance * + static_cast(__ldg(&gamma[col_id])) + + static_cast(__ldg(&beta[col_id]))); +#else + out1[output_offset + col_id] = + (T)(local_out[i] * s_variance * static_cast(gamma[col_id]) + + static_cast(beta[col_id])); +#endif + } + } +} + +template <> +__global__ void preln_layernorm_shift_partition_v2( + half2 *out0_ptr, + half2 *out1_ptr, + const half2 *__restrict input0_ptr, + const half2 *__restrict input1_ptr, + const half2 *__restrict gamma_ptr, + const half2 *__restrict beta_ptr, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + // constexpr int ite = 4; + const int tid = threadIdx.x; + const int batch_offset = blockIdx.z * gridDim.y * gridDim.x; + const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x; + const int shifted_H_idx = + (shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y) + : blockIdx.y; + const int shifted_W_idx = + (shift_size != 0) ? 
((blockIdx.x - shift_size + gridDim.x) % gridDim.x) + : blockIdx.x; + const int window_H_idx = shifted_H_idx / window_size; + const int window_W_idx = shifted_W_idx / window_size; + const int stride_of_window_H = W / window_size; + const int window_idx = window_H_idx * stride_of_window_H + window_W_idx; + const int idx_in_window = (shifted_H_idx % window_size) * window_size + + (shifted_W_idx % window_size); + const int output_bid = + batch_offset + window_idx * window_size * window_size + idx_in_window; + const int offset = bid * n; + const int output_offset = output_bid * n; + __shared__ float s_mean; + __shared__ float s_variance; + float mean = 0.0f; + float variance = 0.0f; + half2 local_out_half2[kITE]; + + // float sum = 0.0f; + half2 sum = __float2half2_rn(0.0f); +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + int index = offset + col_id; + local_out_half2[i] = __ldg(input0_ptr + index); + local_out_half2[i] = + __hadd2(local_out_half2[i], __ldg(input1_ptr + index)); + out0_ptr[i] = local_out_half2[i]; + sum += local_out_half2[i]; + } + } + + mean = blockReduceSum(static_cast(sum.x + sum.y)); + if (threadIdx.x == 0) { + s_mean = mean / (n * 2); + } + __syncthreads(); + + float var = 0.0f; + half2 s_mean_2 = __float2half2_rn(s_mean); +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + local_out_half2[i] = local_out_half2[i] - s_mean_2; + float v1 = static_cast(local_out_half2[i].x); + float v2 = static_cast(local_out_half2[i].y); + var += v1 * v1 + v2 * v2; + } + } + + variance = blockReduceSum(var); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / (n * 2) + eps); + } + __syncthreads(); + + half2 s_var_2 = __float2half2_rn(s_variance); +#pragma unroll + for (int i = 0; i < kITE; i++) { + int col_id = i * blockDim.x + tid; + if (col_id < n) { + out1_ptr[output_offset + col_id] = + local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) + + __ldg(&beta_ptr[col_id]); + } + } +#endif +} + +template +void invokePrelnLayernormShiftPartition(T *out0, + T *out1, + const T *input0, + const T *input1, + const T *gamma, + const T *beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps, + cudaStream_t stream) { + dim3 grid(W, H, batch); + int blockSize = (n + 31) / 32 * 32; + if (blockSize >= 768) { + blockSize = ((blockSize / 4) + 31) / 32 * 32; + preln_layernorm_shift_partition_v2 + <<>>(out0, + out1, + input0, + input1, + gamma, + beta, + batch, + H, + W, + n, + shift_size, + window_size, + eps); + } else { + preln_layernorm_shift_partition + <<>>(out0, + out1, + input0, + input1, + gamma, + beta, + batch, + H, + W, + n, + shift_size, + window_size, + eps); + } +} + +template <> +void invokePrelnLayernormShiftPartition(half *out0, + half *out1, + const half *input0, + const half *input1, + const half *gamma, + const half *beta, + int batch, + int H, + int W, + int n, + int shift_size, + int window_size, + const float eps, + cudaStream_t stream) { + dim3 grid(W, H, batch); + int blockSize = n / 2; + blockSize = (blockSize + 31) / 32 * 32; + + if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) { + blockSize = ((blockSize / 4) + 31) / 32 * 32; + preln_layernorm_shift_partition_v2<<>>( + reinterpret_cast(out0), + reinterpret_cast(out1), + (const half2 *)input0, + (const half2 *)input1, + (const half2 *)gamma, + (const half2 *)beta, + batch, + H, + W, + n / 2, + shift_size, + window_size, + eps); + } else 
{ + preln_layernorm_shift_partition<<>>( + reinterpret_cast(out0), + reinterpret_cast(out1), + (const half2 *)input0, + (const half2 *)input1, + (const half2 *)gamma, + (const half2 *)beta, + batch, + H, + W, + n / 2, + shift_size, + window_size, + eps); + } +} + +template +static void convertAndCopy(const std::vector &host, T *dev) { + T *host_ptr = new T[host.size()]; + std::transform(host.begin(), host.end(), host_ptr, [](float x) { + return static_cast(x); + }); + cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice); + delete host_ptr; +} + +void PrelnLnormShiftPartitionPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT {} + +PrelnLnormShiftPartitionPluginDynamic::PrelnLnormShiftPartitionPluginDynamic( + const float *gamma, + const float *beta, + const int param_num, + int shift_size, + int window_size, + int input_resolution, + float eps, + bool with_fp16, + std::shared_ptr gamma_dev, + std::shared_ptr beta_dev) + : with_fp16_(with_fp16), + window_size_(window_size), + shift_size_(shift_size), + input_resolution_(input_resolution), + eps_(eps), + param_num_(param_num), + gamma_dev_(gamma_dev), + beta_dev_(beta_dev) { + beta_.resize(param_num); + gamma_.resize(param_num); + std::copy(gamma, gamma + param_num, gamma_.data()); + std::copy(beta, beta + param_num, beta_.data()); + int type_size = with_fp16 ? sizeof(half) : sizeof(float); + if (gamma_dev_ == nullptr) { + void *p; + cudaMalloc(reinterpret_cast(&p), param_num_ * type_size); + gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); }); + if (with_fp16) + convertAndCopy(gamma_, reinterpret_cast(p)); + else + convertAndCopy(gamma_, reinterpret_cast(p)); + } + if (beta_dev_ == nullptr) { + void *p; + cudaMalloc(reinterpret_cast(&p), param_num_ * type_size); + beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); }); + if (with_fp16) + convertAndCopy(beta_, reinterpret_cast(p)); + else + convertAndCopy(beta_, reinterpret_cast(p)); + } +} + +PrelnLnormShiftPartitionPluginDynamic::PrelnLnormShiftPartitionPluginDynamic( + void const *serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &beta_); + DeserializeValue(&serialData, &serialLength, &gamma_); + DeserializeValue(&serialData, &serialLength, ¶m_num_); + DeserializeValue(&serialData, &serialLength, &with_fp16_); + DeserializeValue(&serialData, &serialLength, &shift_size_); + DeserializeValue(&serialData, &serialLength, &window_size_); + DeserializeValue(&serialData, &serialLength, &input_resolution_); + DeserializeValue(&serialData, &serialLength, &eps_); + int type_size = with_fp16_ ? 
sizeof(half) : sizeof(float); + { + void *p; + cudaMalloc(reinterpret_cast(&p), param_num_ * type_size); + gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); }); + if (with_fp16_) + convertAndCopy(gamma_, reinterpret_cast(p)); + else + convertAndCopy(gamma_, reinterpret_cast(p)); + } + { + void *p; + cudaMalloc(reinterpret_cast(&p), param_num_ * type_size); + beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); }); + if (with_fp16_) + convertAndCopy(beta_, reinterpret_cast(p)); + else + convertAndCopy(beta_, reinterpret_cast(p)); + } +} + +bool PrelnLnormShiftPartitionPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return in.type == nvinfer1::DataType::kHALF && + in.format == nvinfer1::TensorFormat::kLINEAR; + } else { + return in.type == nvinfer1::DataType::kFLOAT && + in.format == nvinfer1::TensorFormat::kLINEAR; + } + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType PrelnLnormShiftPartitionPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +nvinfer1::DimsExprs PrelnLnormShiftPartitionPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputs, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + if (output_index == 0) return inputs[0]; + + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = expr_builder.operation( + nvinfer1::DimensionOperation::kFLOOR_DIV, + *expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *inputs[0].d[0], + *inputs[0].d[1]), + *expr_builder.constant(window_size_ * window_size_)); + ret.d[1] = expr_builder.constant(window_size_ * window_size_); + ret.d[2] = inputs[0].d[2]; + return ret; +} + +int PrelnLnormShiftPartitionPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, + void *const *outputs, + void *workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto &input_dims = input_desc[0].dims; + auto input_type = input_desc[0].type; + int batch = input_dims.d[0]; + int emb_dim = input_dims.d[2]; + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(3) + << "TRT Plugin DataType selected. PreLayernormShiftPartition-->fp32"; + invokePrelnLayernormShiftPartition( + reinterpret_cast(outputs[0]), + reinterpret_cast(outputs[1]), + reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[1]), + reinterpret_cast(gamma_dev_.get()), + reinterpret_cast(beta_dev_.get()), + batch, + input_resolution_, + input_resolution_, + emb_dim, + shift_size_, + window_size_, + eps_, + stream); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(3) + << "TRT Plugin DataType selected. 
PreLayernormShiftPartition-->half"; + invokePrelnLayernormShiftPartition( + reinterpret_cast(outputs[0]), + reinterpret_cast(outputs[1]), + reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[1]), + reinterpret_cast(gamma_dev_.get()), + reinterpret_cast(beta_dev_.get()), + batch, + input_resolution_, + input_resolution_, + emb_dim, + shift_size_, + window_size_, + eps_, + stream); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h b/paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e832b2144e274bd475a78627a60ee497f37d65b9 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h @@ -0,0 +1,158 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class PrelnLnormShiftPartitionPluginDynamic : public DynamicPluginTensorRT { + public: + PrelnLnormShiftPartitionPluginDynamic( + const float* gamma, + const float* beta, + const int param_num, + int shift_size, + int window_size, + int input_resolution, + float eps, + bool with_fp16, + std::shared_ptr gamma_dev = nullptr, + std::shared_ptr beta_dev = nullptr); + + PrelnLnormShiftPartitionPluginDynamic(void const* serialData, + size_t serialLength); + + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new PrelnLnormShiftPartitionPluginDynamic(gamma_.data(), + beta_.data(), + beta_.size(), + shift_size_, + window_size_, + input_resolution_, + eps_, + with_fp16_, + gamma_dev_, + beta_dev_); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "prelnlnorm_shift_partition_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 2; } + int initialize() TRT_NOEXCEPT override { return 0; } + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(beta_) + SerializedSize(gamma_) + + SerializedSize(param_num_) + SerializedSize(with_fp16_) + + SerializedSize(shift_size_) + SerializedSize(window_size_) + + SerializedSize(input_resolution_) + SerializedSize(eps_); + } + + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, beta_); + SerializeValue(&buffer, gamma_); + SerializeValue(&buffer, param_num_); + SerializeValue(&buffer, with_fp16_); + SerializeValue(&buffer, shift_size_); + SerializeValue(&buffer, window_size_); + SerializeValue(&buffer, input_resolution_); + SerializeValue(&buffer, eps_); + } + 
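+  // The SerializeValue order above must stay in sync with the DeserializeValue
+  // order in the (serialData, serialLength) constructor; a mismatch would make
+  // deserialized engines read the plugin fields back incorrectly.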
+ nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + + private: + bool with_fp16_; + std::vector gamma_; + std::vector beta_; + int window_size_; + int shift_size_; + int input_resolution_; + int param_num_; + float eps_; + std::shared_ptr gamma_dev_; + std::shared_ptr beta_dev_; +}; + +class PrelnLnormShiftPartitionPluginDynamicCreator + : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "prelnlnorm_shift_partition_dynamic"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new PrelnLnormShiftPartitionPluginDynamic(serial_data, + serial_length); + } +}; + +REGISTER_TRT_PLUGIN_V2(PrelnLnormShiftPartitionPluginDynamicCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 21750232f88ee7a367df65e65d4a03b7d43937ab..c5b6e0f2be67f5a9f97827b5d70f8e54cccfae0f 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -173,6 +173,8 @@ if(WITH_GPU AND TENSORRT_FOUND) set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240) + set_tests_properties(test_preln_layernorm_x_fuse_pass PROPERTIES TIMEOUT + 240) set_tests_properties(test_trt_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_trt_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_preln_layernorm_x_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_preln_layernorm_x_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1444b5aecd5babde10388244aa97cb0e1dac2d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_preln_layernorm_x_fuse_pass.py @@ -0,0 +1,274 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +import unittest + +import hypothesis.strategies as st + + +class TestLayernormShiftPartitionPass(PassAutoScanTest): + # + # | | | | + # other_op1 other_op2 other_op1 other_op2 + # | | fuse \ / + # |------elementwise_add -> preln_layernorm_shift_partition + # | | | | + # other_op4 layernorm_shift_partition other_op4 other_op3 + # | + # other_op3 + + def sample_predictor_configs(self, program_config): + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input_data_x": [1, 9, 96], + "input_data_y": [1, 9, 96], + }, + { + "input_data_x": [4, 3136, 768], + "input_data_y": [4, 3136, 768], + }, + { + "input_data_x": [1, 784, 384], + "input_data_y": [1, 784, 384], + }, + ) + yield config, ['preln_layernorm_shift_partition'], (1e-5, 1e-5) + + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Half, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input_data_x": [1, 9, 96], + "input_data_y": [1, 9, 96], + }, + { + "input_data_x": [4, 3136, 768], + "input_data_y": [4, 3136, 768], + }, + { + "input_data_x": [1, 784, 384], + "input_data_y": [1, 784, 384], + }, + ) + yield config, ['preln_layernorm_shift_partition'], (1e-2, 1e-2) + + def sample_program_config(self, draw): + axis = [0, 1, 3, 2, 4, 5] + epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001)) + # begin_norm_axis has to be 2 + begin_norm_axis = 2 + batch_size = draw(st.integers(min_value=1, max_value=4)) + + window_size = draw(st.sampled_from([3, 5, 7])) + move_shape = draw(st.integers(min_value=1, max_value=8)) + dim = draw(st.sampled_from([96, 192, 384, 768])) + + def generate_input(attrs): + return np.random.random( + [attrs[1]["batch_size"], *attrs[1]["input_dim"]] + ).astype(np.float32) + + def generate_weight(attrs): + return np.random.random(attrs[1]['input_dim'][-1]).astype( + np.float32 + ) + + attrs = [ + { + 'begin_norm_axis': begin_norm_axis, + 'epsilon': epsilon, + }, + { + 'batch_size': batch_size, + 'input_dim': [(window_size * move_shape) ** 2, dim], + }, + { + 'axis': axis, + 'input_resolution': window_size * move_shape, + 'move_shape': move_shape, + 'window_size': window_size, + }, + ] + + elementwise_add_op = OpConfig( + type="elementwise_add", + inputs={"X": ["input_data_x"], "Y": ["input_data_y"]}, + outputs={"Out": ["ele_out"]}, + attrs={"axis": -1}, + ) + layer_norm_op = OpConfig( + type="layer_norm", + 
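+            # This layer_norm, together with the reshape2/transpose2 chain that
+            # follows, is expected to be folded into layernorm_shift_partition by
+            # the earlier fuse pass; preln_layernorm_x_fuse_pass then merges it
+            # with the preceding elementwise_add.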
inputs={ + "X": ["ele_out"], + "Bias": ["layer_norm_bias"], + "Scale": ["layer_norm_scale"], + }, + outputs={ + "Y": ["layer_norm_output1"], + "Mean": ["layer_norm_output2"], + "Variance": ["layer_norm_output3"], + }, + attrs={ + "begin_norm_axis": attrs[0]["begin_norm_axis"], + "epsilon": attrs[0]["epsilon"], + }, + ) + reshape_op2 = OpConfig( + type="reshape2", + inputs={ + "X": ["layer_norm_output1"], + }, + outputs={ + "Out": ["reshape_output2"], + "XShape": ["reshape_output2_xshape"], + }, + attrs={ + 'shape': [ + -1, + attrs[2]["input_resolution"], + attrs[2]["input_resolution"], + attrs[1]["input_dim"][-1], + ] + }, + ) + reshape_op3 = OpConfig( + type="reshape2", + inputs={ + "X": ["reshape_output2"], + }, + outputs={ + "Out": ["reshape_output3"], + "XShape": ["reshape_output3_xshape"], + }, + attrs={ + 'shape': [ + -1, + attrs[2]["move_shape"], + attrs[2]["window_size"], + attrs[2]["move_shape"], + attrs[2]["window_size"], + attrs[1]["input_dim"][-1], + ] + }, + ) + transpose_op4 = OpConfig( + type='transpose2', + inputs={ + "X": ["reshape_output3"], + }, + outputs={"Out": ["transpose_output4"]}, + attrs={"axis": attrs[2]['axis']}, + ) + reshape_op5 = OpConfig( + type="reshape2", + inputs={ + "X": ["transpose_output4"], + }, + outputs={ + "Out": ["reshape_output5"], + "XShape": ["reshape_output5_xshape"], + }, + attrs={ + 'shape': [ + -1, + attrs[2]["window_size"], + attrs[2]["window_size"], + attrs[1]["input_dim"][-1], + ] + }, + ) + reshape_op6 = OpConfig( + type="reshape2", + inputs={ + "X": ["reshape_output5"], + }, + outputs={ + "Out": ["reshape_output6"], + "XShape": ["reshape_output6_xshape"], + }, + attrs={ + 'shape': [ + -1, + attrs[2]["window_size"] ** 2, + attrs[1]["input_dim"][-1], + ] + }, + ) + + program_config = ProgramConfig( + ops=[ + elementwise_add_op, + layer_norm_op, + reshape_op2, + reshape_op3, + transpose_op4, + reshape_op5, + reshape_op6, + ], + weights={ + "layer_norm_bias": TensorConfig( + data_gen=partial(generate_weight, attrs) + ), + "layer_norm_scale": TensorConfig( + data_gen=partial(generate_weight, attrs) + ), + }, + inputs={ + "input_data_x": TensorConfig( + data_gen=partial(generate_input, attrs) + ), + "input_data_y": TensorConfig( + data_gen=partial(generate_input, attrs) + ), + }, + outputs=["ele_out", "reshape_output6"], + ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["preln_layernorm_x_fuse_pass"], + max_duration=250, + min_success_num=50, + ) + + +if __name__ == "__main__": + unittest.main()
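For reference, a rough NumPy sketch of the semantics the fused preln_layernorm_shift_partition op is expected to follow, inferred from the plugin kernels and the test graph above (the helper name is illustrative; x and y are [batch, H*W, C] with H == W == input_resolution):

    import numpy as np

    def preln_layernorm_shift_partition_ref(x, y, scale, bias,
                                            window_size, shift_size=0, eps=1e-5):
        out0 = x + y                                  # Out_0: the residual sum
        b, hw, c = out0.shape
        h = w = int(np.sqrt(hw))                      # square input_resolution
        ln = (out0 - out0.mean(-1, keepdims=True)) / np.sqrt(
            out0.var(-1, keepdims=True) + eps) * scale + bias
        ln = ln.reshape(b, h, w, c)
        if shift_size > 0:                            # cyclic shift (Swin-style)
            ln = np.roll(ln, shift=(-shift_size, -shift_size), axis=(1, 2))
        ws = window_size
        ln = ln.reshape(b, h // ws, ws, w // ws, ws, c).transpose(0, 1, 3, 2, 4, 5)
        out1 = ln.reshape(-1, ws * ws, c)             # Out_1: [b*num_windows, ws*ws, c]
        return out0, out1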