Unverified commit 5e9f491e authored by Wang Bojun, committed by GitHub

Merge layernorm trt fuse (#46320)

* first version, accuracy corrected

* disable debug print

* use blockReduceSum in phi

* add UT

* add opCompat

* code style

* code refine

* bug fix

* code refine

* test fix

* bugfix

* codestyle fix

* code style

* code-style

* code-style

* code-style
Parent b7a23adb
@@ -121,6 +121,7 @@ if(WITH_TENSORRT)
  pass_library(trt_map_matmul_to_mul_pass inference)
  pass_library(trt_multihead_matmul_fuse_pass inference)
  pass_library(trt_skip_layernorm_fuse_pass inference)
  pass_library(merge_layernorm_fuse_pass inference)
  pass_library(preln_skip_layernorm_fuse_pass inference)
  pass_library(set_transformer_input_convert_pass inference)
  pass_library(remove_padding_recover_padding_pass inference)
......
@@ -3638,6 +3638,92 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
  return reshape4_out;
}
PDNode *patterns::MergeLayernormPattern::operator()(PDNode *in) {
in->AsInput();
auto reshape2_00_op =
pattern->NewNode(reshape2_00_op_repr())->assert_is_op("reshape2");
auto reshape2_00_out = pattern->NewNode(reshape2_00_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("strided_slice", "Input")
->AsIntermediate();
auto strided_slice_10_op = pattern->NewNode(strided_slice_10_op_repr())
->assert_is_op("strided_slice");
auto strided_slice_10_out = pattern->NewNode(strided_slice_10_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_nth_input("concat", "X", 0)
->AsIntermediate();
auto strided_slice_11_op = pattern->NewNode(strided_slice_11_op_repr())
->assert_is_op("strided_slice");
auto strided_slice_11_out = pattern->NewNode(strided_slice_11_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_nth_input("concat", "X", 1)
->AsIntermediate();
auto strided_slice_12_op = pattern->NewNode(strided_slice_12_op_repr())
->assert_is_op("strided_slice");
auto strided_slice_12_out = pattern->NewNode(strided_slice_12_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_nth_input("concat", "X", 2)
->AsIntermediate();
auto strided_slice_13_op = pattern->NewNode(strided_slice_13_op_repr())
->assert_is_op("strided_slice");
auto strided_slice_13_out = pattern->NewNode(strided_slice_13_out_repr())
->assert_is_op_output("strided_slice", "Out")
->assert_is_op_nth_input("concat", "X", 3)
->AsIntermediate();
auto concat_20_op = pattern->NewNode(concat_20_op_repr())
->assert_is_op("concat")
->assert_has_n_inputs(4);
auto concat_20_out = pattern->NewNode(concat_20_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("reshape2", "X")
->AsIntermediate();
auto reshape2_30_op =
pattern->NewNode(reshape2_30_op_repr())->assert_is_op("reshape2");
auto reshape2_30_out = pattern->NewNode(reshape2_30_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("layer_norm", "X")
->AsIntermediate();
auto layernorm_40_op =
pattern->NewNode(layernorm_40_op_repr())
->assert_is_op("layer_norm")
->assert_more([&](Node *node) {
return node->Op()->HasAttr("begin_norm_axis") &&
(PADDLE_GET_CONST(
int, node->Op()->GetAttr("begin_norm_axis")) == 2);
});
auto layernorm_40_in_bias = pattern->NewNode(layernorm_40_in_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->AsInput();
auto layernorm_40_in_scale = pattern->NewNode(layernorm_40_in_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->AsInput();
auto layernorm_40_out = pattern->NewNode(layernorm_40_out_repr())
->assert_is_op_output("layer_norm", "Y")
->AsOutput();
reshape2_00_op->LinksFrom({in});
reshape2_00_out->LinksFrom({reshape2_00_op});
strided_slice_10_op->LinksFrom({reshape2_00_out});
strided_slice_10_out->LinksFrom({strided_slice_10_op});
strided_slice_11_op->LinksFrom({reshape2_00_out});
strided_slice_11_out->LinksFrom({strided_slice_11_op});
strided_slice_12_op->LinksFrom({reshape2_00_out});
strided_slice_12_out->LinksFrom({strided_slice_12_op});
strided_slice_13_op->LinksFrom({reshape2_00_out});
strided_slice_13_out->LinksFrom({strided_slice_13_op});
concat_20_op->LinksFrom({strided_slice_10_out,
strided_slice_11_out,
strided_slice_12_out,
strided_slice_13_out});
concat_20_out->LinksFrom({concat_20_op});
reshape2_30_op->LinksFrom({concat_20_out});
reshape2_30_out->LinksFrom({reshape2_30_op});
layernorm_40_op->LinksFrom(
{reshape2_30_out, layernorm_40_in_bias, layernorm_40_in_scale});
layernorm_40_out->LinksFrom({layernorm_40_op});
return layernorm_40_out;
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -1946,6 +1946,33 @@ struct LayernormShiftPartitionPattern : public PatternBase {
  PATTERN_DECL_NODE(reshape4_out);
};
// pattern for merge_layernorm
struct MergeLayernormPattern : public PatternBase {
MergeLayernormPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "merge_layernorm") {}
PDNode* operator()(PDNode* reshape2_in);
PATTERN_DECL_NODE(reshape2_00_op);
PATTERN_DECL_NODE(reshape2_00_out);
PATTERN_DECL_NODE(strided_slice_10_op);
PATTERN_DECL_NODE(strided_slice_10_out);
PATTERN_DECL_NODE(strided_slice_11_op);
PATTERN_DECL_NODE(strided_slice_11_out);
PATTERN_DECL_NODE(strided_slice_12_op);
PATTERN_DECL_NODE(strided_slice_12_out);
PATTERN_DECL_NODE(strided_slice_13_op);
PATTERN_DECL_NODE(strided_slice_13_out);
PATTERN_DECL_NODE(concat_20_op);
PATTERN_DECL_NODE(concat_20_out);
PATTERN_DECL_NODE(reshape2_30_op);
PATTERN_DECL_NODE(reshape2_30_out);
PATTERN_DECL_NODE(layernorm_40_op);
PATTERN_DECL_NODE(layernorm_40_in_bias);
PATTERN_DECL_NODE(layernorm_40_in_scale);
PATTERN_DECL_NODE(layernorm_40_out);
};
// Add support int8 flag
struct AddSupportInt8 : public PatternBase {
  AddSupportInt8(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/merge_layernorm_fuse_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(reshape2_00_op); \
GET_IR_NODE(reshape2_00_out); \
GET_IR_NODE(strided_slice_10_op); \
GET_IR_NODE(strided_slice_10_out); \
GET_IR_NODE(strided_slice_11_op); \
GET_IR_NODE(strided_slice_11_out); \
GET_IR_NODE(strided_slice_12_op); \
GET_IR_NODE(strided_slice_12_out); \
GET_IR_NODE(strided_slice_13_op); \
GET_IR_NODE(strided_slice_13_out); \
GET_IR_NODE(concat_20_op); \
GET_IR_NODE(concat_20_out); \
GET_IR_NODE(reshape2_30_op); \
GET_IR_NODE(reshape2_30_out); \
GET_IR_NODE(layernorm_40_op); \
GET_IR_NODE(layernorm_40_in_bias); \
GET_IR_NODE(layernorm_40_in_scale); \
GET_IR_NODE(layernorm_40_out);
namespace paddle {
namespace framework {
namespace ir {
MergeLayernormFusePass::MergeLayernormFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("strided_slice"))
.AddInput("Input")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axes")
.IsType<std::vector<int>>()
.End()
.AddAttr("starts")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("infer_flags")
.IsType<std::vector<int>>()
.End()
.AddAttr("ends")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X")
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Variance")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumEQ(2)
.End();
}
void MergeLayernormFusePass::ApplyImpl(ir::Graph* graph) const {
GraphPatternDetector gpd;
const std::string pattern_name = "merge_layernorm";
FusePassBase::Init(pattern_name, graph);
PDNode* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("reshape2", "X")
->AsInput();
patterns::MergeLayernormPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
int fusion_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_NODES;
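    // Build a single merge_layernorm op in place of the matched subgraph,
    // wiring it to the original input and the layer_norm Bias/Scale weights.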
OpDesc merge_layer_op_desc(reshape2_00_op->Op()->Block());
merge_layer_op_desc.SetType("merge_layernorm");
merge_layer_op_desc.SetInput("X", {subgraph.at(x)->Name()});
merge_layer_op_desc.SetInput("Bias", {layernorm_40_in_bias->Name()});
merge_layer_op_desc.SetInput("Scale", {layernorm_40_in_scale->Name()});
merge_layer_op_desc.SetOutput("Y", {layernorm_40_out->Name()});
merge_layer_op_desc.SetAttr(
"begin_norm_axis", layernorm_40_op->Op()->GetAttr("begin_norm_axis"));
merge_layer_op_desc.SetAttr("epsilon",
layernorm_40_op->Op()->GetAttr("epsilon"));
auto* merge_layer_op_node = graph->CreateOpNode(&merge_layer_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), merge_layer_op_node);
IR_NODE_LINK_TO(layernorm_40_in_bias, merge_layer_op_node);
IR_NODE_LINK_TO(layernorm_40_in_scale, merge_layer_op_node);
IR_NODE_LINK_TO(merge_layer_op_node, layernorm_40_out);
GraphSafeRemoveNodes(graph,
{reshape2_00_op,
reshape2_00_out,
strided_slice_10_op,
strided_slice_10_out,
strided_slice_11_op,
strided_slice_11_out,
strided_slice_12_op,
strided_slice_12_out,
strided_slice_13_op,
strided_slice_13_out,
concat_20_op,
concat_20_out,
reshape2_30_op,
reshape2_30_out,
layernorm_40_op});
++fusion_count;
};
gpd(graph, handler);
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(merge_layernorm_fuse_pass,
paddle::framework::ir::MergeLayernormFusePass);
REGISTER_PASS_CAPABILITY(merge_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("reshape2", 0)
.EQ("concat", 0)
.EQ("layer_norm", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
// Fusing of patch merging (Swin Transformer style) and layer_norm
// op: ss = strided_slice
// shape: [ss] = [?x28x28x96]
// input
// | [?x3136x96]
// reshape2 input
// | [?x56x56x96] | [?x3136x96]
// |------|------|------| merge_layernorm
// ss ss ss ss -> | [?x784x384]
// | [ss] | [ss] | [ss] | [ss] fused output
// |------|------|------|
// concat
// | [?x28x28x384]
// reshape2
// | [?x784x384]
// layer_norm
// | [?x784x384]
// output
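//
// Reference semantics: every 2x2 spatial window of the reshaped input is
// flattened into the channel dimension (4x fewer tokens, 4x more channels),
// and layer_norm is then applied over the last dimension.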
class MergeLayernormFusePass : public FusePassBase {
public:
MergeLayernormFusePass();
virtual ~MergeLayernormFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2260,6 +2260,7 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(merge_layernorm)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
USE_TRT_CONVERTER(lookup_table)
......
@@ -110,6 +110,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
      "trt_skip_layernorm_fuse_pass",           //
      "preln_skip_layernorm_fuse_pass",         //
      "layernorm_shift_partition_fuse_pass",    //
      "merge_layernorm_fuse_pass",              //
      "preln_residual_bias_fuse_pass",          //
      // "set_transformer_input_convert_pass",  //
      "conv_bn_fuse_pass",                      //
......
@@ -77,6 +77,7 @@ list(
  fill_constant_op.cc
  fused_token_prune_op.cc
  layernorm_shift_partition_op.cc
  merge_layernorm_op.cc
  generic_and_custom_plugin_creater.cc
  fused_lookup_tables_op.cc
  expand_v2_op.cc)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class MergeLayernormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid merge_layernorm op to tensorrt merge_layernorm "
"plugin";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
const int begin_norm_axis =
op_desc.HasAttr("begin_norm_axis")
? PADDLE_GET_CONST(int, op_desc.GetAttr("begin_norm_axis"))
: 1;
const float eps = op_desc.HasAttr("epsilon")
? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
: 1e-5f;
PADDLE_ENFORCE_NOT_NULL(
Bias_v,
platform::errors::InvalidArgument(
"Input(Bias) of layer_norm should not be null."));
PADDLE_ENFORCE_NOT_NULL(
Scale_v,
platform::errors::InvalidArgument(
"Input(Scale) of layer_norm should not be null."));
PADDLE_ENFORCE_EQ(
begin_norm_axis,
2,
platform::errors::InvalidArgument(
"The begin_norm_axis of LayernormShiftPartition should be %d",
begin_norm_axis));
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
auto bias_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
auto scale_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
nvinfer1::ILayer* merge_layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::MergeLayernormPluginDynamic* plugin =
new plugin::MergeLayernormPluginDynamic(
static_cast<const float*>(bias_weight.get().values),
bias_weight.get().count,
static_cast<const float*>(scale_weight.get().values),
scale_weight.get().count,
eps,
begin_norm_axis,
with_fp16);
merge_layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Currently, MergeLayernorm TRT Plugin only support dynamic shape "
"mode."));
}
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(
merge_layernorm_layer, "merge_layernorm", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(merge_layernorm, MergeLayernormOpConverter);
@@ -2100,6 +2100,13 @@ struct SimpleOpTypeSetTeller : public Teller {
        return false;
      }
    }
if (op_type == "merge_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "The merge_layernorm op does not support "
"static shape yet";
return false;
}
}
if (op_type == "lookup_table") { if (op_type == "lookup_table") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
...@@ -2369,6 +2376,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2369,6 +2376,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"unsqueeze2", "unsqueeze2",
"fused_token_prune", "fused_token_prune",
"layernorm_shift_partition", "layernorm_shift_partition",
"merge_layernorm",
"lookup_table", "lookup_table",
"lookup_table_v2", "lookup_table_v2",
"expand_v2"}; "expand_v2"};
......
@@ -33,6 +33,7 @@ list(
  preln_residual_bias_plugin.cu
  fused_token_prune_op_plugin.cu
  layernorm_shift_partition_op.cu
  merge_layernorm_op_plugin.cu
  generic_plugin.cu
  lookup_table.cu)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#define FINAL_MASK 0xffffffff
template <typename T>
__global__ void merge_layernorm_v2(T *out,
const T *__restrict input,
const T *__restrict gamma,
const T *__restrict beta,
const float layernorm_eps,
int batch,
int H,
int W,
int n) {
// input is [batch, 2*H, 2*W, n/4]
// output is [batch, H, W, n]
// grid (W, H, batch)
// block (n)
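  // One thread block normalizes a single output pixel (H_idx, W_idx) of one
  // batch element; each thread handles up to kIte channels, so n may be as
  // large as kIte * blockDim.x.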
const int kIte = 4;
const int tid = threadIdx.x;
const int W_idx = blockIdx.x;
const int H_idx = blockIdx.y;
const size_t batch_offset = blockIdx.z * H * W * n;
const int input_H_stride = W * n / 2;
const int output_H_stride = W * n;
const int n_4 = n >> 2;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float local_out[kIte];
float sum = 0.0f;
#pragma unroll
for (int i = 0; i < kIte; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
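      // col_id / n_4 identifies which of the four strided_slice quadrants
      // this channel came from; the offsets below undo the channel concat:
      // part 0 -> (dh 0, dw 0), 1 -> (1, 0), 2 -> (0, 1), 3 -> (1, 1).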
int part_id = col_id / n_4;
int offset_in_W = part_id / 2;
int offset_in_H = part_id % 2;
size_t input_id = batch_offset +
(2 * H_idx + offset_in_H) * input_H_stride +
(2 * W_idx + offset_in_W) * n_4 + (col_id % n_4);
local_out[i] = static_cast<float>(__ldg(input + input_id));
sum += local_out[i];
}
}
mean = phi::funcs::blockReduceSum<float>(sum, FINAL_MASK);
if (tid == 0) {
s_mean = mean / n;
}
__syncthreads();
float var = 0.0f;
#pragma unroll
for (int i = 0; i < kIte; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
local_out[i] = local_out[i] - s_mean;
var += local_out[i] * local_out[i];
}
}
variance = phi::funcs::blockReduceSum<float>(var, FINAL_MASK);
if (tid == 0) {
s_variance = rsqrtf(variance / n + layernorm_eps);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < kIte; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
size_t output_idx =
batch_offset + H_idx * output_H_stride + W_idx * n + col_id;
out[output_idx] =
static_cast<T>(local_out[i] * s_variance *
static_cast<float>(__ldg(&gamma[col_id])) +
static_cast<float>(__ldg(&beta[col_id])));
}
}
}
template <typename T>
void invokeMergeLayernorm(T *output,
const T *input,
const T *gamma,
const T *beta,
float layernorm_eps,
int batch,
int H,
int W,
int n,
cudaStream_t stream) {
if ((W % 2 != 0) || (H % 2 != 0)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"H(W) of merge layernorm should be a multiple of 2."));
}
dim3 grid(W / 2, H / 2, batch);
int blockSize = (n + 31) / 32 * 32;
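  // The kernel works on the output geometry: an (H/2 x W/2) grid of pixels
  // with n*4 channels each. blockSize rounds n up to a warp multiple; the
  // kIte loop inside the kernel covers the remaining channels.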
merge_layernorm_v2<T><<<grid, blockSize, 0, stream>>>(
output, input, gamma, beta, layernorm_eps, batch, H / 2, W / 2, n * 4);
}
template void invokeMergeLayernorm<float>(float *output,
const float *input,
const float *gamma,
const float *beta,
float layernorm_eps,
int batch,
int H,
int W,
int n,
cudaStream_t stream);
template void invokeMergeLayernorm<half>(half *output,
const half *input,
const half *gamma,
const half *beta,
float layernorm_eps,
int batch,
int H,
int W,
int n,
cudaStream_t stream);
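// A minimal host-side reference of the same transform, written from the
// diagram in merge_layernorm_fuse_pass.h. Illustrative sketch only (assumes
// contiguous row-major buffers; this helper is not used by the plugin) but
// handy for checking the kernel's index math.
static void MergeLayernormRef(float *out, const float *in, const float *gamma,
                              const float *beta, float eps, int batch, int H,
                              int W, int n) {
  // in: [batch, 2H, 2W, n/4]; out: [batch, H, W, n].
  const int n4 = n / 4;
  for (int b = 0; b < batch; ++b) {
    for (int h = 0; h < H; ++h) {
      for (int w = 0; w < W; ++w) {
        float *o = out + ((static_cast<size_t>(b) * H + h) * W + w) * n;
        // Gather the 2x2 window in the same quadrant order as the kernel.
        for (int c = 0; c < n; ++c) {
          const int part = c / n4, dw = part / 2, dh = part % 2;
          const size_t idx =
              ((static_cast<size_t>(b) * 2 * H + 2 * h + dh) * 2 * W + 2 * w +
               dw) * n4 + c % n4;
          o[c] = in[idx];
        }
        // Plain layer_norm over the last dimension.
        float mean = 0.f, var = 0.f;
        for (int c = 0; c < n; ++c) mean += o[c];
        mean /= n;
        for (int c = 0; c < n; ++c) var += (o[c] - mean) * (o[c] - mean);
        const float rstd = 1.f / sqrtf(var / n + eps);
        for (int c = 0; c < n; ++c)
          o[c] = (o[c] - mean) * rstd * gamma[c] + beta[c];
      }
    }
  }
}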
template <typename T>
static void convertAndCopy(const std::vector<float> &host, T *dev) {
T *host_ptr = new T[host.size()];
std::transform(host.begin(), host.end(), host_ptr, [](float x) {
return static_cast<T>(x);
});
cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice);
  delete[] host_ptr;
}
void MergeLayernormPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT {}
MergeLayernormPluginDynamic::MergeLayernormPluginDynamic(
const float *bias_d,
const size_t bias_num,
const float *scale_d,
const size_t scale_num,
const float eps,
const int begin_norm_axis,
const bool with_fp16,
std::shared_ptr<void> bias_device,
std::shared_ptr<void> scale_device)
: eps_(eps),
begin_norm_axis_(begin_norm_axis),
with_fp16_(with_fp16),
bias_device_(bias_device),
scale_device_(scale_device) {
bias_.resize(bias_num);
scale_.resize(scale_num);
std::copy(bias_d, bias_d + bias_num, bias_.data());
std::copy(scale_d, scale_d + scale_num, scale_.data());
int type_size = with_fp16_ ? sizeof(half) : sizeof(float);
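  // The device-side copies are created lazily: when constructed from the
  // converter they do not exist yet, so allocate them once here; clone()
  // passes the shared_ptrs along so the same GPU buffers are reused.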
if (bias_device_ == nullptr) {
void *p;
cudaMalloc(&p, bias_num * type_size);
bias_device_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16) {
convertAndCopy<half>(bias_, reinterpret_cast<half *>(p));
} else {
convertAndCopy<float>(bias_, reinterpret_cast<float *>(p));
}
}
if (scale_device_ == nullptr) {
void *p;
cudaMalloc(&p, scale_num * type_size);
scale_device_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16) {
convertAndCopy<half>(scale_, reinterpret_cast<half *>(p));
} else {
convertAndCopy<float>(scale_, reinterpret_cast<float *>(p));
}
}
}
bool MergeLayernormPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument("The input of MergeLayernorm "
"plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return in.type == nvinfer1::DataType::kHALF &&
in.format == nvinfer1::TensorFormat::kLINEAR;
} else {
return in.type == nvinfer1::DataType::kFLOAT &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType MergeLayernormPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(index,
0,
platform::errors::InvalidArgument(
"The MergeLayernorm only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
nvinfer1::DimsExprs MergeLayernormPluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(output_index,
0,
platform::errors::InvalidArgument(
"There is only one output of the MergeLayernorm, "
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs,
1,
platform::errors::InvalidArgument(
"The Input of the MergeLayernorm should be 1, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
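  // Output shape is [batch, H*W/4, 4*C]: the 2x2 merge divides the token
  // count by 4 and multiplies the channel count by 4.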
ret.d[0] = inputs[0].d[0];
ret.d[1] = expr_builder.operation(nvinfer1::DimensionOperation::kFLOOR_DIV,
*inputs[0].d[1],
*expr_builder.constant(4));
ret.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[0].d[2],
*expr_builder.constant(4));
return ret;
}
int MergeLayernormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const auto &input_dims = input_desc[0].dims;
auto input_type = input_desc[0].type;
int batch = input_dims.d[0];
int input_resolution = static_cast<int>(std::sqrt(input_dims.d[1]));
int dim = static_cast<int>(input_dims.d[2]);
PADDLE_ENFORCE_EQ(
input_resolution * input_resolution,
input_dims.d[1],
platform::errors::InvalidArgument(
"The MergeLayernorm TRT Plugin get invalid input_resolution %d",
input_resolution));
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(3) << "TRT Plugin DataType selected. MergeLayernorm-->fp32";
invokeMergeLayernorm<float>(
reinterpret_cast<float *>(outputs[0]),
reinterpret_cast<const float *>(inputs[0]),
reinterpret_cast<const float *>(scale_device_.get()),
reinterpret_cast<const float *>(bias_device_.get()),
eps_,
batch,
input_resolution,
input_resolution,
dim,
stream);
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(3) << "TRT Plugin DataType selected. MergeLayernorm-->fp16";
invokeMergeLayernorm<half>(
reinterpret_cast<half *>(outputs[0]),
reinterpret_cast<const half *>(inputs[0]),
reinterpret_cast<const half *>(scale_device_.get()),
reinterpret_cast<const half *>(bias_device_.get()),
eps_,
batch,
input_resolution,
input_resolution,
dim,
stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The MergeLayernorm TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class MergeLayernormPluginDynamic : public DynamicPluginTensorRT {
public:
MergeLayernormPluginDynamic(const float* bias_d,
const size_t bias_num,
const float* scale_d,
const size_t scale_num,
const float eps,
const int begin_norm_axis,
const bool with_fp16,
std::shared_ptr<void> bias_device = nullptr,
std::shared_ptr<void> scale_device = nullptr);
MergeLayernormPluginDynamic(void const* serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new MergeLayernormPluginDynamic(bias_.data(),
bias_.size(),
scale_.data(),
scale_.size(),
eps_,
begin_norm_axis_,
with_fp16_,
bias_device_,
scale_device_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "merge_layernorm_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(bias_) + SerializedSize(scale_) +
SerializedSize(eps_) + SerializedSize(begin_norm_axis_) +
SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, begin_norm_axis_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
std::vector<float> bias_;
std::vector<float> scale_;
float eps_;
int begin_norm_axis_;
bool with_fp16_;
std::shared_ptr<void> bias_device_ = nullptr;
std::shared_ptr<void> scale_device_ = nullptr;
};
class MergeLayernormPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "merge_layernorm_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new MergeLayernormPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(MergeLayernormPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -122,6 +122,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
  #set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
  set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120)
  set_tests_properties(test_trt_inspector PROPERTIES TIMEOUT 60)
  set_tests_properties(test_merge_layernorm_fuse_pass PROPERTIES TIMEOUT 180)
  if(WITH_NV_JETSON)
    set_tests_properties(
      test_trt_pool_op
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
import unittest
import hypothesis.strategies as st
class TestMergeLayernormFusePass(PassAutoScanTest):
# input
# | [?x3136x96]
# reshape2 input
# | [?x56x56x96] | [?x3136x96]
# |--------------|--------------|--------------| merge_layernorm
# strided_slice strided_slice strided_slice strided_slice -> | [?x784x384]
# | [?x28x28x96] | [?x28x28x96] | [?x28x28x96] | fused output
# |--------------|--------------|--------------|
# concat
# | [?x28x28x384]
# reshape2
# | [?x784x384]
# layer_norm
# | [?x784x384]
# output
def sample_predictor_configs(self, program_config):
# trt dynamic_shape fp32
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=1 << 20,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({"input_data": [1, 196, 96]},
{"input_data": [4, 3136, 384]},
{"input_data": [1, 3136, 96]})
yield config, ["merge_layernorm"], (1e-5, 1e-5)
# trt dynamic_shape fp16
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=1 << 20,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({"input_data": [1, 196, 96]},
{"input_data": [4, 3136, 384]},
{"input_data": [1, 3136, 96]})
yield config, ["merge_layernorm"], (1e-3, 1e-3)
def sample_program_config(self, draw):
batch_size = draw(st.integers(min_value=1, max_value=4))
input_H_W = draw(st.sampled_from([56, 28, 14]))
input_n = draw(st.sampled_from([96, 192, 384]))
layernorm_40_begin_norm_axis = 2
layernorm_40_epsilon = draw(
st.floats(min_value=0.0000001, max_value=0.001))
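        # The four strided_slice ops below start at (0,0), (1,0), (0,1), (1,1)
        # on the H/W axes with stride 2, i.e. Swin-style 2x2 patch merging.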
def generate_input(attrs):
return np.random.random([
attrs[3]['batch_size'],
attrs[3]['input_H_W'] * attrs[3]['input_H_W'],
attrs[3]['input_n']
]).astype(np.float32)
def generate_weight(attrs):
return np.random.random([attrs[3]['input_n'] * 4
]).astype(np.float32)
attrs = [{
'shape': [-1, input_H_W, input_H_W, input_n]
}, {
'shape': [-1, int(input_H_W * input_H_W / 4),
int(input_n * 4)]
}, {
'begin_norm_axis': layernorm_40_begin_norm_axis,
'epsilon': layernorm_40_epsilon
}, {
'batch_size': batch_size,
'input_H_W': input_H_W,
'input_n': input_n
}]
reshape2_00_op = OpConfig(type="reshape2",
inputs={'X': ['input_data']},
outputs={
'Out': ['reshape2_00_out'],
'XShape': ['reshape2_00_outxshape']
},
attrs={'shape': attrs[0]['shape']})
strided_slice_10_op = OpConfig(
type="strided_slice",
inputs={'Input': ['reshape2_00_out']},
outputs={'Out': ['strided_slice_10_out']},
attrs={
'axes': [1, 2],
'starts': [0, 0],
'infer_flags': [1, 1],
'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']],
'strides': [2, 2]
})
strided_slice_11_op = OpConfig(
type="strided_slice",
inputs={'Input': ['reshape2_00_out']},
outputs={'Out': ['strided_slice_11_out']},
attrs={
'axes': [1, 2],
'starts': [1, 0],
'infer_flags': [1, 1],
'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']],
'strides': [2, 2]
})
strided_slice_12_op = OpConfig(
type="strided_slice",
inputs={'Input': ['reshape2_00_out']},
outputs={'Out': ['strided_slice_12_out']},
attrs={
'axes': [1, 2],
'starts': [0, 1],
'infer_flags': [1, 1],
'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']],
'strides': [2, 2]
})
strided_slice_13_op = OpConfig(
type="strided_slice",
inputs={'Input': ['reshape2_00_out']},
outputs={'Out': ['strided_slice_13_out']},
attrs={
'axes': [1, 2],
'starts': [1, 1],
'infer_flags': [1, 1],
'ends': [attrs[3]['input_H_W'], attrs[3]['input_H_W']],
'strides': [2, 2]
})
concat_20_op = OpConfig(type="concat",
inputs={
'X': [
'strided_slice_10_out',
'strided_slice_11_out',
'strided_slice_12_out',
'strided_slice_13_out'
]
},
outputs={'Out': ['concat_20_out']},
attrs={'axis': -1})
reshape2_30_op = OpConfig(type='reshape2',
inputs={'X': ['concat_20_out']},
outputs={
'Out': ['reshape2_30_Out'],
'XShape': ['reshape2_30_XShape']
},
attrs={'shape': attrs[1]['shape']})
layernorm_40_op = OpConfig(type='layer_norm',
inputs={
'X': ['reshape2_30_Out'],
'Bias': ['layer_norm_bias'],
'Scale': ['layer_norm_scale']
},
outputs={
"Y": ["layer_norm_out"],
"Mean": ["layer_norm_outMean"],
"Variance": ["layer_norm_outVariance"]
},
attrs={
'begin_norm_axis':
attrs[2]['begin_norm_axis'],
'epsilon':
attrs[2]['epsilon']
})
program_config = ProgramConfig(
ops=[
reshape2_00_op, strided_slice_10_op, strided_slice_11_op,
strided_slice_12_op, strided_slice_13_op, concat_20_op,
reshape2_30_op, layernorm_40_op
],
weights={
'layer_norm_bias':
TensorConfig(data_gen=partial(generate_weight, attrs)),
'layer_norm_scale':
TensorConfig(data_gen=partial(generate_weight, attrs))
},
inputs={
'input_data':
TensorConfig(data_gen=partial(generate_input, attrs))
},
outputs=['layer_norm_out'])
return program_config
def test(self):
self.run_and_statis(quant=False,
max_examples=50,
passes=["merge_layernorm_fuse_pass"],
max_duration=250,
min_success_num=50)
if __name__ == "__main__":
unittest.main()