Unverified commit b2bb7ec9, authored by Wang Bojun, committed by GitHub

[TRT] Transpose layernorm fusion with different input format (#50082)

* trans_layernorm
Parent b3f60f39
...@@ -145,6 +145,7 @@ if(WITH_TENSORRT)
pass_library(elementwise_groupnorm_act_pass inference)
pass_library(preln_elementwise_groupnorm_act_pass inference)
pass_library(groupnorm_act_pass inference)
pass_library(trans_layernorm_fuse_pass inference)
pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
endif()
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/trans_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TransLayernormPattern : public PatternBase {
TransLayernormPattern(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "trans_layernorm") {}
void operator()(PDNode *x);
PATTERN_DECL_NODE(transpose);
PATTERN_DECL_NODE(transpose_output);
PATTERN_DECL_NODE(reshape);
PATTERN_DECL_NODE(reshape_output);
PATTERN_DECL_NODE(layernorm);
PATTERN_DECL_NODE(layernorm_scale);
PATTERN_DECL_NODE(layernorm_bias);
PATTERN_DECL_NODE(layernorm_output);
};
void TransLayernormPattern::operator()(PDNode *x) {
std::unordered_set<std::string> reshape_ops{"reshape2",
"flatten_contiguous_range"};
auto *transpose =
pattern->NewNode(transpose_repr())->assert_is_op("transpose2");
auto *transpose_output = pattern->NewNode(transpose_output_repr())
->assert_is_op_output("transpose2")
->assert_is_ops_input(reshape_ops, "X");
transpose->LinksFrom({x}).LinksTo({transpose_output});
auto *reshape = pattern->NewNode(reshape_repr())->assert_is_ops(reshape_ops);
auto *reshape_output = pattern->NewNode(reshape_output_repr())
->assert_is_ops_output(reshape_ops, "Out")
->assert_is_op_input("layer_norm", "X")
->AsOutput();
reshape->LinksFrom({transpose_output}).LinksTo({reshape_output});
auto *layernorm =
pattern->NewNode(layernorm_repr())->assert_is_op("layer_norm");
auto *layernorm_scale = pattern->NewNode(layernorm_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->AsInput();
auto *layernorm_bias = pattern->NewNode(layernorm_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->AsInput();
auto *layernorm_output = pattern->NewNode(layernorm_output_repr())
->assert_is_op_output("layer_norm", "Y")
->AsOutput();
layernorm->LinksFrom({reshape_output, layernorm_scale, layernorm_bias})
.LinksTo({layernorm_output});
}
} // namespace patterns
// this pass makes a fusion as below:
//
// |
// transpose(axis= [0,2,3,1])
// |
// reshape(n,h*w,c)
// | |
// out layernorm(begin_norm_axis=2 or -1)
// |
// layernorm_out
//
// ->fuse to
//
// |
// trans_layernorm
// | |
// out layernorm_out
//
int TransLayernormFusePass::ApplyConvTransLayernormPattern(
ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("trans_layernorm_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
x = gpd.mutable_pattern()
->NewNode("trans_layernorm_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("transpose2", "X");
patterns::TransLayernormPattern fused_pattern(gpd.mutable_pattern(),
"trans_layernorm_fuse");
fused_pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle transpose layernorm fuse";
GET_IR_NODE_FROM_SUBGRAPH(transpose, transpose, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose_output, transpose_output, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape, reshape, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_output, reshape_output, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layernorm, layernorm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layernorm_scale, layernorm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layernorm_bias, layernorm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layernorm_output, layernorm_output, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "transpose layernorm pass in op compat failed.";
return;
}
// trans_layernorm is suited to an nchw-to-nhwc transpose before layernorm,
// so check for that here
std::vector<int> trans_axis =
PADDLE_GET_CONST(std::vector<int>, transpose->Op()->GetAttr("axis"));
if (trans_axis != std::vector<int>{0, 2, 3, 1}) {
VLOG(1) << "transpose layernorm fuse pass, transpose axis check fail, "
"stop fusion";
return;
}
if (reshape->Op()->Type() == "flatten_contiguous_range") {
int start_axis =
PADDLE_GET_CONST(int, reshape->Op()->GetAttr("start_axis"));
int stop_axis =
PADDLE_GET_CONST(int, reshape->Op()->GetAttr("stop_axis"));
if (!(start_axis == 1 && stop_axis == 2)) {
VLOG(1) << "transpose layernorm fuse pass, flatten axis check fail, "
"stop fusion";
return;
}
} else if (reshape->Op()->Type() == "reshape2") {
std::vector<int> reshape_shape =
PADDLE_GET_CONST(std::vector<int>, reshape->Op()->GetAttr("shape"));
if (reshape_shape.size() != 3) {
VLOG(1)
<< "transpose layernorm fuse pass, reshape check fail, stop fusion";
return;
}
}
auto layernorm_begin_norm_axis =
PADDLE_GET_CONST(int, layernorm->Op()->GetAttr("begin_norm_axis"));
if (layernorm_begin_norm_axis != 2 && layernorm_begin_norm_axis != -1) {
VLOG(1) << "transpose layernorm fuse pass, layernorm begin norm axis "
"check fail, stop fusion";
return;
}
std::unordered_set<const Node *> del_node_set;
// Create a trans_layernorm op node
OpDesc new_desc(*layernorm->Op());
new_desc.SetType("trans_layernorm");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetOutput("Out_reshape", {reshape_output->Name()});
new_desc.SetOutput("Out_layernorm", {layernorm_output->Name()});
new_desc.RemoveOutput("Y");
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(transpose);
del_node_set.insert(transpose_output);
del_node_set.insert(reshape);
del_node_set.insert(layernorm);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(layernorm_scale, fused_node);
IR_NODE_LINK_TO(layernorm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, reshape_output);
IR_NODE_LINK_TO(fused_node, layernorm_output);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void TransLayernormFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("trans_layernorm_fuse_pass", graph);
int found_subgraph_count = ApplyConvTransLayernormPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(trans_layernorm_fuse_pass,
paddle::framework::ir::TransLayernormFusePass);
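The pass above matches a transpose(axis=[0,2,3,1]) -> reshape/flatten -> layer_norm chain and replaces it with a single trans_layernorm op that keeps both the reshaped output and the layernorm output. For reference, the matched subgraph can be written in Paddle's Python API roughly as below; this is only a sketch, with shapes and variable names chosen for illustration (they are not taken from the PR):

# Sketch of the pattern targeted by trans_layernorm_fuse_pass (illustrative shapes).
import paddle
import paddle.nn.functional as F

n, c, h, w = 2, 128, 32, 32
x = paddle.randn([n, c, h, w])                  # NCHW feature map
scale = paddle.ones([c])
bias = paddle.zeros([c])

y = paddle.transpose(x, perm=[0, 2, 3, 1])      # NCHW -> NHWC
y = paddle.reshape(y, [n, h * w, c])            # or paddle.flatten(y, 1, 2)
out = y                                         # first output ("Out_reshape")
ln_out = F.layer_norm(y, normalized_shape=[c],  # last-axis norm, i.e. begin_norm_axis == 2
                      weight=scale, bias=bias,
                      epsilon=1e-5)             # second output ("Out_layernorm")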
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// This pass aims to fuse the structure below
//
// |
// transpose(axis= [0,2,3,1])
// |
// reshape(n,h*w,c)
// | |
// out layernorm(begin_norm_axis=2 or -1)
// |
// layernorm_out
//
// ->fuse to
//
// |
// trans_layernorm
// | |
// out layernorm_out
//
class Graph;
class TransLayernormFusePass : public FusePassBase {
public:
TransLayernormFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("flatten_contiguous_range"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("start_axis")
.IsNumEQ(1)
.End()
.AddAttr("stop_axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End();
}
virtual ~TransLayernormFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyConvTransLayernormPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -2490,6 +2490,7 @@ USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(reverse_roll)
USE_TRT_CONVERTER(preln_layernorm_shift_partition)
USE_TRT_CONVERTER(merge_layernorm)
USE_TRT_CONVERTER(trans_layernorm)
USE_TRT_CONVERTER(skip_merge_layernorm)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
......
...@@ -127,6 +127,10 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass",  //
"fc_fuse_pass",  //
"conv_elementwise_add_fuse_pass",  //
#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading.
#else
"trans_layernorm_fuse_pass", //
#endif
"remove_padding_recover_padding_pass", // "remove_padding_recover_padding_pass", //
"delete_remove_padding_recover_padding_pass", // "delete_remove_padding_recover_padding_pass", //
// "yolo_box_fuse_pass", // // "yolo_box_fuse_pass", //
......
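The pass is appended to kTRTSubgraphPasses above (skipped on Windows, where CI still runs TensorRT 7.0). If the fusion ever needs to be turned off for debugging, it can be dropped from the analysis config; a minimal sketch using Paddle Inference's standard pass-control API (the model paths are placeholders):

import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # placeholder paths
config.delete_pass("trans_layernorm_fuse_pass")  # skip the fusion added by this PR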
...@@ -92,6 +92,7 @@ list(
take_along_axis_op.cc
logsigmoid_op.cc
preln_layernorm_shift_partition_op.cc
trans_layernorm_op.cc
merge_layernorm_op.cc
skip_merge_layernorm_op.cc
generic_and_custom_plugin_creater.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class TransLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a trans_layer_norm fused op to tensorrt "
"trans_layernorm plugin";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
// the begin_norm_axis has already been checked in the fuse pass;
// set begin_norm_axis to 3 here to fit the calculation in the TRT plugin.
const int begin_norm_axis = 3;
const float eps = op_desc.HasAttr("epsilon")
? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
: 1e-5f;
PADDLE_ENFORCE_NOT_NULL(
Bias_v,
platform::errors::InvalidArgument(
"Input(Bias) of layer_norm should not be null."));
PADDLE_ENFORCE_NOT_NULL(
Scale_v,
platform::errors::InvalidArgument(
"Input(Scale) of layer_norm should not be null."));
auto* Bias_t = Bias_v->GetMutable<phi::DenseTensor>();
auto* Scale_t = Scale_v->GetMutable<phi::DenseTensor>();
auto bias_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
auto scale_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
nvinfer1::ILayer* layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
// For dynamic shape,
// the shapes of mean and variance will be determined in configurePlugin.
std::vector<int64_t> mean_shape{1};
std::vector<int64_t> variance_shape{1};
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::TransLayerNormPluginDynamic* plugin =
new plugin::TransLayerNormPluginDynamic(
static_cast<const float*>(bias_weight.get().values),
bias_weight.get().count,
static_cast<const float*>(scale_weight.get().values),
scale_weight.get().count,
begin_norm_axis,
eps,
mean_shape,
variance_shape,
with_fp16);
layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"trans_layernorm do not support static shape mode yet"));
}
auto output_layernorm_name = op_desc.Output("Out_layernorm").front();
auto output_reshape_name = op_desc.Output("Out_reshape").front();
RreplenishLayerAndOutput(layernorm_layer,
"trans_layernorm",
{output_layernorm_name, output_reshape_name},
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(trans_layernorm, TransLayerNormOpConverter);
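The converter builds the plugin only under dynamic shape (static shape throws). A minimal sketch of a Paddle Inference configuration that enables TensorRT with a dynamic-shape profile so this op can be converted; the model paths, input tensor name, and shapes are assumptions for illustration:

import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(1000, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=4,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
# Dynamic-shape profile for the network input (name/shapes are illustrative).
config.set_trt_dynamic_shape_info(
    {"conv2d_input": [1, 128, 32, 32]},   # min
    {"conv2d_input": [4, 128, 64, 64]},   # max
    {"conv2d_input": [4, 128, 64, 64]},   # opt
)
predictor = paddle_infer.create_predictor(config)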
...@@ -2489,7 +2489,13 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
if (op_type == "trans_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "The trans_layernorm op does not support "
"static shape yet";
return false;
}
}
if (op_type == "lookup_table") { if (op_type == "lookup_table") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
VLOG(3) << "the lookup_table does not support " VLOG(3) << "the lookup_table does not support "
...@@ -2659,6 +2665,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"logsigmoid",
"preln_layernorm_shift_partition",
"lookup_table",
"trans_layernorm",
"merge_layernorm", "merge_layernorm",
"skip_merge_layernorm", "skip_merge_layernorm",
"lookup_table_v2", "lookup_table_v2",
...@@ -2808,6 +2815,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2808,6 +2815,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"take_along_axis", "take_along_axis",
"logsigmoid", "logsigmoid",
"preln_layernorm_shift_partition", "preln_layernorm_shift_partition",
"trans_layernorm",
"merge_layernorm", "merge_layernorm",
"skip_merge_layernorm", "skip_merge_layernorm",
"lookup_table", "lookup_table",
......
...@@ -33,6 +33,7 @@ list(
layernorm_shift_partition_op.cu
reverse_roll_op_plugin.cu
prelnlayernorm_shift_partition_op.cu
trans_layernorm_op_plugin.cu
merge_layernorm_op_plugin.cu
skip_merge_layernorm_op_plugin.cu
skip_groupnorm_act_op_plugin.cu
......
...@@ -26,12 +26,21 @@
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
namespace plugin { namespace plugin {
inline int getSMVersion() {
const int device = phi::backends::gpu::GetCurrentDeviceId();
const phi::gpuDeviceProp prop =
phi::backends::gpu::GetDeviceProperties(device);
return prop.major * 10 + prop.minor;
}
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#define FINAL_MASK 0xffffffff
...@@ -425,7 +434,10 @@ int PrelnResidualBiasPluginDynamic::enqueue(
float *mean = nullptr;
float *var = nullptr;
const int VecSize = 8;
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int sm = getSMVersion();
// sm >= 60 to support __ldg
if (sm >= 60) {
// if hidden is even, use half2 kernel generalAddBiasResidualLayerNormOpt2
if (hidden % 2 == 0) {
int half_n = hidden / 2;
...@@ -479,7 +491,8 @@ int PrelnResidualBiasPluginDynamic::enqueue(
var,
stream);
}
#else
} else {
// if sm < 60, use FusedLayernormResidualDropoutBiasFunctor only
paddle::operators::FusedLayernormResidualDropoutBiasFunctor<half,
uint8_t,
VecSize,
...@@ -504,7 +517,7 @@ int PrelnResidualBiasPluginDynamic::enqueue(
mean,
var,
stream);
#endif
}
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) tensorRT plugin should be "
......
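The change above replaces the compile-time CUDA_ARCH guard with a runtime SM check. For reference, the same compute-capability number can be queried from Python; a small sketch (assumes a CUDA build of Paddle):

import paddle

major, minor = paddle.device.cuda.get_device_capability()
sm = major * 10 + minor  # e.g. 70 for V100, 80 for A100
# sm >= 60 takes the half2/__ldg optimized path, otherwise the fallback functor is used.
print("SM version:", sm)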
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class TransLayerNormPluginDynamic : public DynamicPluginTensorRT {
public:
TransLayerNormPluginDynamic(const float* bias,
const int bias_num,
const float* scale,
const int scale_num,
int begin_norm_axis,
float eps,
std::vector<int64_t> mean_shape,
std::vector<int64_t> variance_shape,
bool with_fp16)
: begin_norm_axis_(begin_norm_axis),
eps_(eps),
mean_shape_(mean_shape),
variance_shape_(variance_shape) {
with_fp16_ = with_fp16;
bias_.resize(bias_num);
scale_.resize(scale_num);
std::copy(bias, bias + bias_num, bias_.data());
std::copy(scale, scale + scale_num, scale_.data());
}
TransLayerNormPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &mean_shape_);
DeserializeValue(&serialData, &serialLength, &variance_shape_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
auto ptr = new TransLayerNormPluginDynamic(bias_.data(),
bias_.size(),
scale_.data(),
scale_.size(),
begin_norm_axis_,
eps_,
mean_shape_,
variance_shape_,
with_fp16_);
ptr->bias_gpu_ = bias_gpu_;
ptr->scale_gpu_ = scale_gpu_;
ptr->fp16_bias_gpu_ = fp16_bias_gpu_;
ptr->fp16_scale_gpu_ = fp16_scale_gpu_;
return ptr;
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "trans_layernorm_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
int initialize() TRT_NOEXCEPT override;
void terminate() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(bias_) + SerializedSize(scale_) +
SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, begin_norm_axis_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, mean_shape_);
SerializeValue(&buffer, variance_shape_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
std::vector<float> bias_;
std::vector<float> scale_;
phi::DenseTensor mean_t;
phi::DenseTensor variance_t;
int begin_norm_axis_;
float eps_;
std::vector<int64_t> mean_shape_;
std::vector<int64_t> variance_shape_;
// data on devices
float* bias_gpu_{nullptr};
float* scale_gpu_{nullptr};
half* fp16_bias_gpu_{nullptr};
half* fp16_scale_gpu_{nullptr};
};
class TransLayerNormPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "trans_layernorm_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new TransLayerNormPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(TransLayerNormPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
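For reference, what the plugin's two outputs contain can be expressed as a short NumPy sketch that mirrors the pattern diagram (this is a reference computation, not the CUDA kernel itself):

import numpy as np

def trans_layernorm_ref(x, scale, bias, eps=1e-5):
    # x: NCHW input of the fused op; scale/bias: per-channel layernorm parameters.
    n, c, h, w = x.shape
    out_reshape = x.transpose(0, 2, 3, 1).reshape(n, h * w, c)   # "Out_reshape"
    mean = out_reshape.mean(axis=-1, keepdims=True)
    var = out_reshape.var(axis=-1, keepdims=True)
    out_layernorm = (out_reshape - mean) / np.sqrt(var + eps) * scale + bias  # "Out_layernorm"
    return out_reshape, out_layernorm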
...@@ -31,6 +31,9 @@ if(WIN32)
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_element_groupnorm_act_fuse_pass")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_groupnorm_act_pass_fuse_pass")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_trans_layernorm")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_trans_layernorm")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_trans_layernorm")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
endif()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertTransLayernormTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
def conv_filter_datagen(dics):
c = dics["c"]
x = (np.random.randn(c, c, 1, 1)) / np.sqrt(c)
return x.astype(np.float32)
def elementwise_bias_datagen(dics):
c = dics["c"]
x = np.random.random([c]) * 0.01
return x.astype(np.float32)
def layernorm_bias_datagen(dics):
c = dics["c"]
x = np.random.random([c]) * 0.1
return x.astype(np.float32)
def layernorm_scale_datagen(dics):
c = dics["c"]
x = np.ones([c])
return x.astype(np.float32)
def conv2d_input_datagen(dics):
x = np.random.randn(dics["batch"], dics["c"], dics["h"], dics["w"])
x = (x - np.mean(x)) / (np.std(x))
return x.astype(np.float32)
for batch in [2]:
for begin_norm_axis in [2]:
for h in [32, 64]:
for w in [32, 64]:
for c in [128, 320, 255, 133]:
for reshape in ["flatten", "reshape"]:
dics = {
"batch": batch,
"begin_norm_axis": begin_norm_axis,
"h": h,
"w": w,
"c": c,
"flatten": {
"op_type": "flatten_contiguous_range",
"op_inputs": {
"X": ["transpose2_out"],
},
"op_outputs": {
"Out": ["reshape_out"],
},
"op_attrs": {
"start_axis": 1,
"stop_axis": 2,
},
},
"reshape": {
"op_type": "reshape2",
"op_inputs": {
"X": ["transpose2_out"],
},
"op_outputs": {
"Out": ["reshape_out"],
},
"op_attrs": {"shape": [-1, h * w, c]},
},
}
ops_config = [
{
"op_type": "conv2d",
"op_inputs": {
"Input": ["conv2d_input"],
"Filter": ["conv2d_filter"],
},
"op_outputs": {
"Output": ["conv2d_output"],
},
"op_attrs": {
"dilations": [1, 1],
"padding_algorithm": "EXPLICIT",
"groups": 1,
"paddings": [0, 0],
"strides": [1, 1],
"data_format": "NCHW",
},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["conv2d_output"],
"Y": ["elementwise_bias"],
},
"op_outputs": {
"Out": ["elementwise_out"]
},
"op_attrs": {"axis": 1},
},
{
"op_type": "transpose2",
"op_inputs": {
"X": ["elementwise_out"],
},
"op_outputs": {
"Out": ["transpose2_out"],
},
"op_attrs": {"axis": [0, 2, 3, 1]},
},
dics[reshape],
{
"op_type": "layer_norm",
"op_inputs": {
"X": ["reshape_out"],
"Bias": ["layernorm_bias"],
"Scale": ["layernorm_scale"],
},
"op_outputs": {
"Y": ["layernorm_out"],
"Mean": ["layernorm_mean"],
"Variance": ["layernorm_variance"],
},
"op_attrs": {
"epsilon": 1e-5,
"begin_norm_axis": dics[
"begin_norm_axis"
],
},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"conv2d_filter": TensorConfig(
data_gen=partial(
conv_filter_datagen, dics
)
),
"elementwise_bias": TensorConfig(
data_gen=partial(
elementwise_bias_datagen, dics
)
),
"layernorm_bias": TensorConfig(
data_gen=partial(
layernorm_bias_datagen, dics
)
),
"layernorm_scale": TensorConfig(
data_gen=partial(
layernorm_scale_datagen, dics
)
),
},
inputs={
"conv2d_input": TensorConfig(
data_gen=partial(
conv2d_input_datagen, dics
)
),
},
outputs=["reshape_out", "layernorm_out"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs, inputs):
conv2d_c = inputs['conv2d_input'].shape[1]
self.dynamic_shape.min_input_shape = {
"conv2d_input": [1, conv2d_c, 32, 32],
"conv2d_filter": [conv2d_c, conv2d_c, 1, 1],
"elementwise_bias": [conv2d_c],
"layernorm_bias": [conv2d_c],
"layernorm_scale": [conv2d_c],
}
self.dynamic_shape.max_input_shape = {
"conv2d_input": [4, conv2d_c, 64, 64],
"conv2d_filter": [conv2d_c, conv2d_c, 1, 1],
"elementwise_bias": [conv2d_c],
"layernorm_bias": [conv2d_c],
"layernorm_scale": [conv2d_c],
}
self.dynamic_shape.opt_input_shape = {
"conv2d_input": [4, conv2d_c, 64, 64],
"conv2d_filter": [conv2d_c, conv2d_c, 1, 1],
"elementwise_bias": [conv2d_c],
"layernorm_bias": [conv2d_c],
"layernorm_scale": [conv2d_c],
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 3
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
inputs = program_config.inputs
# only dynamic_shape is supported
generate_dynamic_shape(attrs, inputs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (
1e-2,
1e-2,
) # tol 1e-2 for half
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()