Unverified commit 960109af, authored by wenbin and committed by GitHub

Layernorm shift partition (#45736)

* first commit

* convert done

* correct format

* layernorm_shift_partition

* correct convert

* redefine plugin

* runnable

* bug fix

* modify ShiftPartitionPattern

* correct

* add UT

* modify ut

* compile

* modify enforce

* modify UT
Parent commit b77fa1d9
@@ -174,6 +174,7 @@ if(WITH_TENSORRT)
pass_library(set_transformer_input_convert_pass inference)
pass_library(remove_padding_recover_padding_pass inference)
pass_library(delete_remove_padding_recover_padding_pass inference)
pass_library(layernorm_shift_partition_fuse_pass inference)
endif()
if(WITH_GPU OR WITH_ROCM)
......
@@ -3502,6 +3502,106 @@ PDNode *patterns::AddSupportInt8::operator()() {
return quant_out;
}
PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
auto layer_norm_op =
pattern->NewNode(layer_norm_op_repr())
->assert_is_op("layer_norm")
->assert_more([&](Node *node) {
return node->Op()->HasAttr("begin_norm_axis") &&
(PADDLE_GET_CONST(
int, node->Op()->GetAttr("begin_norm_axis")) == 2);
});
auto layer_norm_in = pattern->NewNode(layer_norm_in_repr())
->AsInput()
->assert_is_op_input("layer_norm", "X");
auto layer_norm_bias = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_op_input("layer_norm", "Bias");
auto layer_norm_scale = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_op_input("layer_norm", "Scale");
auto layer_norm_out = pattern->NewNode(layer_norm_out_repr())
->AsIntermediate()
->assert_is_op_input("reshape2", "X")
->assert_is_op_output("layer_norm", "Y");
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())
->assert_is_op("reshape2")
->assert_more([&](Node *node) {
return node->Op()->HasAttr("shape") &&
(PADDLE_GET_CONST(std::vector<int>,
node->Op()->GetAttr("shape"))
.size() == 4);
});
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->AsIntermediate()
->assert_is_op_input("reshape2", "X")
->assert_is_op_output("reshape2", "Out");
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())
->assert_is_op("reshape2")
->assert_more([&](Node *node) {
return node->Op()->HasAttr("shape") &&
(PADDLE_GET_CONST(std::vector<int>,
node->Op()->GetAttr("shape"))
.size() == 6);
});
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->AsIntermediate()
->assert_is_op_input("transpose2", "X")
->assert_is_op_output("reshape2", "Out");
auto transpose_op =
pattern->NewNode(transpose_op_repr())
->assert_is_op("transpose2")
->assert_more([&](Node *node) {
if (!node->Op()->HasAttr("axis")) return false;
std::vector<int> axis =
PADDLE_GET_CONST(std::vector<int>, node->Op()->GetAttr("axis"));
if (axis.size() != 6) return false;
const std::vector<int> axis_cmp{0, 1, 3, 2, 4, 5};
return std::equal(axis.begin(), axis.end(), axis_cmp.begin());
});
auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate()
->assert_is_op_input("reshape2", "X")
->assert_is_op_output("transpose2", "Out");
auto reshape3_op =
pattern->NewNode(reshape3_op_repr())
->assert_is_op("reshape2")
->assert_more([&](Node *node) {
return node->Op()->HasAttr("shape") &&
(PADDLE_GET_CONST(std::vector<int>,
node->Op()->GetAttr("shape"))
.size() == 4);
});
auto reshape3_out = pattern->NewNode(reshape3_out_repr())
->AsIntermediate()
->assert_is_op_input("reshape2", "X")
->assert_is_op_output("reshape2", "Out");
auto reshape4_op =
pattern->NewNode(reshape4_op_repr())
->assert_is_op("reshape2")
->assert_more([&](Node *node) {
return node->Op()->HasAttr("shape") &&
(PADDLE_GET_CONST(std::vector<int>,
node->Op()->GetAttr("shape"))
.size() == 3);
});
auto reshape4_out = pattern->NewNode(reshape4_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
layer_norm_op->LinksFrom({layer_norm_in, layer_norm_bias, layer_norm_scale})
.LinksTo({layer_norm_out});
reshape1_op->LinksFrom({layer_norm_out}).LinksTo({reshape1_out});
reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
transpose_op->LinksFrom({reshape2_out}).LinksTo({transpose_out});
reshape3_op->LinksFrom({transpose_out}).LinksTo({reshape3_out});
reshape4_op->LinksFrom({reshape3_out}).LinksTo({reshape4_out});
return reshape4_out;
}
} // namespace ir
} // namespace framework
} // namespace paddle
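Note: as an aid for reading the pattern above, here is a minimal NumPy sketch (an illustration only, not part of the patch) of the reshape/transpose pipeline the pattern matches, i.e. the window-partition sequence used by Swin-style models. B, H, W, C and the window size ws are assumed example values.

import numpy as np

# Assumed example sizes; any H and W divisible by ws behave the same way.
B, H, W, C, ws = 2, 14, 14, 96, 7
x = np.random.rand(B, H * W, C).astype(np.float32)

# layer_norm over the last axis (begin_norm_axis == 2)
y = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)

y = y.reshape(B, H, W, C)                      # reshape2 #1
y = y.reshape(B, H // ws, ws, W // ws, ws, C)  # reshape2 #2
y = y.transpose(0, 1, 3, 2, 4, 5)              # transpose2, axis=[0,1,3,2,4,5]
y = y.reshape(-1, ws, ws, C)                   # reshape2 #3
y = y.reshape(-1, ws * ws, C)                  # reshape2 #4
print(y.shape)                                 # (8, 49, 96) == (B*(H//ws)*(W//ws), ws*ws, C)

The fused layernorm_shift_partition op produces this final [B*num_windows, ws*ws, C] tensor directly from the [B, H*W, C] input.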
@@ -1911,6 +1911,34 @@ struct LayerNorm : public PatternBase {
PATTERN_DECL_NODE(shift_out);
};
//
// \brief Pattern looking for a subgraph representing the
// layernorm_shift_partition operation.
//
struct LayernormShiftPartitionPattern : public PatternBase {
LayernormShiftPartitionPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "layernorm_shift_partition") {}
PDNode* operator()();
PATTERN_DECL_NODE(layer_norm_in);
PATTERN_DECL_NODE(layer_norm_op);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(reshape3_op);
PATTERN_DECL_NODE(reshape3_out);
PATTERN_DECL_NODE(reshape4_op);
PATTERN_DECL_NODE(reshape4_out);
};
// Add support int8 flag
struct AddSupportInt8 : public PatternBase {
AddSupportInt8(PDPattern* pattern, const std::string& name_scope)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
LayerNormShiftPartitionFusePass::LayerNormShiftPartitionFusePass() {
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Variance")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumEQ(2)
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
}
void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::InvalidArgument(
"The input graph of LayerNormShiftPartitionFusePass should not be "
"nullptr."));
FusePassBase::Init(scope_name_, graph);
GraphPatternDetector gpd;
patterns::LayernormShiftPartitionPattern shift_partition_pattern(
gpd.mutable_pattern(), scope_name_);
shift_partition_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "op compat check failed in layernorm_shift_partition_fuse_pass.";
return;
}
VLOG(4) << "layernorm_shift_partition_fuse pass";
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_in, layer_norm_in, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_op, layer_norm_op, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape1_op, reshape1_op, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape1_out, reshape1_out, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_out, reshape2_out, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose_op, transpose_op, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
transpose_out, transpose_out, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape3_op, reshape3_op, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape3_out, reshape3_out, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape4_op, reshape4_op, shift_partition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape4_out, reshape4_out, shift_partition_pattern);
std::vector<int> shape_attr1 =
PADDLE_GET_CONST(std::vector<int>, reshape1_op->Op()->GetAttr("shape"));
std::vector<int> shape_attr2 =
PADDLE_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
std::vector<int> shape_attr3 =
PADDLE_GET_CONST(std::vector<int>, reshape3_op->Op()->GetAttr("shape"));
std::vector<int> shape_attr4 =
PADDLE_GET_CONST(std::vector<int>, reshape4_op->Op()->GetAttr("shape"));
// The embedding dim (last entry) must be the same across all four reshapes.
if (!((shape_attr1.back() == shape_attr2.back()) &&
(shape_attr2.back() == shape_attr3.back()) &&
(shape_attr3.back() == shape_attr4.back()))) {
return;
}
if (shape_attr1[1] != shape_attr1[2]) {
return;
}
int input_resolution = shape_attr1[1];
if (shape_attr3[1] != shape_attr3[2]) {
return;
}
int window_size = shape_attr2[2];
if (window_size < 0 || input_resolution < 0) {
return;
}
OpDesc new_op_desc;
new_op_desc.SetType("layernorm_shift_partition");
new_op_desc.SetInput("X", {layer_norm_in->Name()});
new_op_desc.SetInput("Bias", {layer_norm_bias->Name()});
new_op_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_op_desc.SetOutput("Y", {reshape4_out->Name()});
new_op_desc.SetAttr("epsilon", layer_norm_op->Op()->GetAttr("epsilon"));
new_op_desc.SetAttr("begin_norm_axis",
layer_norm_op->Op()->GetAttr("begin_norm_axis"));
new_op_desc.SetAttr("window_size", window_size);
new_op_desc.SetAttr("input_resolution", input_resolution);
new_op_desc.Flush();
auto* layernorm_shift_partition = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(layer_norm_in, layernorm_shift_partition);
IR_NODE_LINK_TO(layer_norm_bias, layernorm_shift_partition);
IR_NODE_LINK_TO(layer_norm_scale, layernorm_shift_partition);
IR_NODE_LINK_TO(layernorm_shift_partition, reshape4_out);
GraphSafeRemoveNodes(graph,
{layer_norm_op,
layer_norm_out,
reshape1_op,
reshape1_out,
reshape2_op,
reshape2_out,
transpose_op,
transpose_out,
reshape3_op,
reshape3_out,
reshape4_op});
++found_count;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(layernorm_shift_partition_fuse_pass,
paddle::framework::ir::LayerNormShiftPartitionFusePass);
REGISTER_PASS_CAPABILITY(layernorm_shift_partition_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
//        |
//   layer_norm
//        |
//    reshape2
//        |
//    reshape2                           |
//        |        fuse     layernorm_shift_partition
//   transpose2    ---->                 |
//        |                           other_op
//    reshape2
//        |
//    reshape2
//        |
//    other_op
class LayerNormShiftPartitionFusePass : public FusePassBase {
public:
LayerNormShiftPartitionFusePass();
virtual ~LayerNormShiftPartitionFusePass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
private:
const std::string scope_name_{"layernorm_shift_partition_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2184,6 +2184,7 @@ USE_TRT_CONVERTER(sum)
USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
@@ -105,6 +105,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", //
"layernorm_shift_partition_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
......
@@ -75,7 +75,8 @@ list(
sum_op.cc
shape_op.cc
fill_constant_op.cc
fused_token_prune_op.cc)
fused_token_prune_op.cc
layernorm_shift_partition_op.cc)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class LayerNormShiftPartitionOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid layernorm_shift_partition op to tensorrt "
"layernorm_shift_partition plugin";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
const int begin_norm_axis =
op_desc.HasAttr("begin_norm_axis")
? PADDLE_GET_CONST(int, op_desc.GetAttr("begin_norm_axis"))
: 1;
const float eps = op_desc.HasAttr("epsilon")
? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
: 1e-5f;
const int window_size =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
const int input_resolution =
PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
// int shift_size = window_size / 2;
// shift_size = (input_resolution <= window_size) ? 0 : shift_size;
int shift_size = 0;
PADDLE_ENFORCE_NOT_NULL(
Bias_v,
platform::errors::InvalidArgument(
"Input(Bias) of layer_norm should not be null."));
PADDLE_ENFORCE_NOT_NULL(
Scale_v,
platform::errors::InvalidArgument(
"Input(Scale) of layer_norm should not be null."));
PADDLE_ENFORCE_EQ(
begin_norm_axis,
2,
platform::errors::InvalidArgument(
"The begin_norm_axis of LayernormShiftPartition should be %d",
begin_norm_axis));
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
auto bias_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
auto scale_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
PADDLE_ENFORCE_EQ(bias_weight.get().count,
scale_weight.get().count,
platform::errors::InvalidArgument(
"The num between bias_weight and cale_weight should "
"be equal. (%d vs %d)",
bias_weight.get().count,
scale_weight.get().count));
nvinfer1::ILayer* layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::LayernormShiftPartitionPluginDynamic* plugin =
new plugin::LayernormShiftPartitionPluginDynamic(
static_cast<const float*>(scale_weight.get().values),
static_cast<const float*>(bias_weight.get().values),
bias_weight.get().count,
shift_size,
window_size,
input_resolution,
eps,
with_fp16);
layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"LayernormShiftPartition TRT Plugin should run in dynamic shape."));
}
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(
layernorm_layer, "layernorm_shift_partition", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(layernorm_shift_partition,
LayerNormShiftPartitionOpConverter);
@@ -176,7 +176,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"sum",
"shape",
"squeeze2",
"unsqueeze2"};
"unsqueeze2",
"layernorm_shift_partition"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
@@ -286,7 +287,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"shape",
"squeeze2",
"unsqueeze2",
"fused_token_prune"};
"fused_token_prune",
"layernorm_shift_partition"};
};
bool OpTeller::Tell(const framework::ir::Node* node,
@@ -2246,6 +2248,14 @@ bool OpTeller::Tell(const framework::ir::Node* node,
#endif
}
if (op_type == "layernorm_shift_partition") {
if (!with_dynamic_shape) {
VLOG(3) << "the layernorm_shift_partition does not support "
"static shape yet";
return false;
}
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
......
@@ -31,7 +31,8 @@ list(
recover_padding_plugin.cu
c_allreduce_op_plugin.cu
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu)
fused_token_prune_op_plugin.cu
layernorm_shift_partition_op.cu)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND TRT_FILES spmm_plugin.cu)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layernorm_shift_partition_op.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0) shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (rather than blockDim.x >> 5) so that the partial
// warp is still counted when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <typename T>
__global__ void layernorm_shift_partition(T *out,
const T *input,
const T *gamma,
const T *beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
int tid = threadIdx.x;
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float local_out =
(tid < n) ? static_cast<float>(__ldg(input + bid * n + tid)) : 0.0f;
mean = blockReduceSum<float>(local_out);
if (threadIdx.x == 0) {
s_mean = mean / n;
}
__syncthreads();
float diff = (tid < n) ? (local_out - s_mean) : 0.0f;
variance = blockReduceSum<float>(diff * diff);
if (threadIdx.x == 0) {
s_variance = variance / n + eps;
}
__syncthreads();
if (tid < n) {
out[output_bid * n + tid] =
(T)(((local_out - s_mean) * rsqrtf(s_variance)) *
static_cast<float>(__ldg(&gamma[tid])) +
static_cast<float>(__ldg(&beta[tid])));
}
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__global__ void layernorm_shift_partition(half2 *out_ptr,
const half2 *input_ptr,
const half2 *gamma_ptr,
const half2 *beta_ptr,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
int tid = threadIdx.x;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float2 local_out_fp2;
float local_out = 0.0f;
int id = bid * n + tid;
if (tid < n) {
local_out_fp2 = __half22float2(__ldg(input_ptr + id));
local_out += local_out_fp2.x;
local_out += local_out_fp2.y;
}
mean = blockReduceSum<float>(local_out);
if (threadIdx.x == 0) {
s_mean = mean / (n * 2);
}
__syncthreads();
if (tid < n) {
variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean);
variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean);
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (n * 2) + eps);
}
__syncthreads();
if (tid < n) {
float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid]));
float2 beta_val = __half22float2(__ldg(&beta_ptr[tid]));
local_out_fp2.x =
(local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x;
local_out_fp2.y =
(local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y;
out_ptr[output_bid * n + tid] = __float22half2_rn(local_out_fp2);
}
}
#endif
#define kITE 4
template <typename T>
__global__ void layernorm_shift_partition_v2(T *out,
const T *__restrict input,
const T *__restrict gamma,
const T *__restrict beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
// constexpr int kITE = 4;
const int tid = threadIdx.x;
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
const int offset = bid * n;
const int output_offset = output_bid * n;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float local_out[kITE];
float sum = 0.0f;
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
local_out[i] = static_cast<float>(__ldg(input + offset + col_id));
sum += local_out[i];
}
}
mean = blockReduceSum<float>(sum);
if (tid == 0) {
s_mean = mean / n;
}
__syncthreads();
float var = 0.0f;
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
float diff = local_out[i] - s_mean;
local_out[i] = diff;
var += diff * diff;
}
}
variance = blockReduceSum<float>(var);
if (tid == 0) {
s_variance = rsqrtf(variance / n + eps);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
out[output_offset + col_id] =
(T)(local_out[i] * s_variance *
static_cast<float>(__ldg(&gamma[col_id])) +
static_cast<float>(__ldg(&beta[col_id])));
}
}
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__global__ void layernorm_shift_partition_v2(half2 *out_ptr,
const half2 *__restrict input_ptr,
const half2 *__restrict gamma_ptr,
const half2 *__restrict beta_ptr,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
// constexpr int ite = 4;
const int tid = threadIdx.x;
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
const int offset = bid * n;
const int output_offset = output_bid * n;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
half2 local_out_half2[kITE];
const half2 zero = {static_cast<half>(0.0f), static_cast<half>(0.0f)};
// float sum = 0.0f;
half2 sum = __float2half2_rn(0.0f);
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
local_out_half2[i] = __ldg(input_ptr + offset + col_id);
sum += local_out_half2[i];
}
}
mean = blockReduceSum<float>(static_cast<float>(sum.x + sum.y));
if (threadIdx.x == 0) {
s_mean = mean / (n * 2);
}
__syncthreads();
float var = 0.0f;
half2 s_mean_2 = __float2half2_rn(s_mean);
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
local_out_half2[i] = local_out_half2[i] - s_mean_2;
float v1 = static_cast<float>(local_out_half2[i].x);
float v2 = static_cast<float>(local_out_half2[i].y);
var += v1 * v1 + v2 * v2;
}
}
variance = blockReduceSum<float>(var);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (n * 2) + eps);
}
__syncthreads();
half2 s_var_2 = __float2half2_rn(s_variance);
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
out_ptr[output_offset + col_id] =
local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) +
__ldg(&beta_ptr[col_id]);
}
}
}
#endif
template <typename T>
void invokeLayernormShiftPartition(T *out,
const T *input,
const T *gamma,
const T *beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps,
cudaStream_t stream) {
dim3 grid(W, H, batch);
int blockSize = (n + 31) / 32 * 32;
if (blockSize >= 768) {
blockSize = ((blockSize / 4) + 31) / 32 * 32;
layernorm_shift_partition_v2<T><<<grid, blockSize, 0, stream>>>(
out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps);
} else {
layernorm_shift_partition<T><<<grid, blockSize, 0, stream>>>(
out, input, gamma, beta, batch, H, W, n, shift_size, window_size, eps);
}
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
void invokeLayernormShiftPartition(half *out,
const half *input,
const half *gamma,
const half *beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps,
cudaStream_t stream) {
dim3 grid(W, H, batch);
int blockSize = n / 2;
blockSize = (blockSize + 31) / 32 * 32;
if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) {
blockSize = ((blockSize / 4) + 31) / 32 * 32;
layernorm_shift_partition_v2<<<grid, blockSize, 0, stream>>>(
reinterpret_cast<half2 *>(out),
(const half2 *)input,
(const half2 *)gamma,
(const half2 *)beta,
batch,
H,
W,
n / 2,
shift_size,
window_size,
eps);
} else {
layernorm_shift_partition<<<grid, blockSize, 0, stream>>>(
reinterpret_cast<half2 *>(out),
(const half2 *)input,
(const half2 *)gamma,
(const half2 *)beta,
batch,
H,
W,
n / 2,
shift_size,
window_size,
eps);
}
}
#endif
template <typename T>
static void convertAndCopy(const std::vector<float> &host, T *dev) {
T *host_ptr = new T[host.size()];
std::transform(host.begin(), host.end(), host_ptr, [](float x) {
return static_cast<T>(x);
});
cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice);
delete[] host_ptr;
}
void LayernormShiftPartitionPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT {}
LayernormShiftPartitionPluginDynamic::LayernormShiftPartitionPluginDynamic(
const float *gamma,
const float *beta,
const int param_num,
int shift_size,
int window_size,
int input_resolution,
float eps,
bool with_fp16,
std::shared_ptr<void> gamma_dev,
std::shared_ptr<void> beta_dev)
: with_fp16_(with_fp16),
window_size_(window_size),
shift_size_(shift_size),
input_resolution_(input_resolution),
eps_(eps),
param_num_(param_num),
gamma_dev_(gamma_dev),
beta_dev_(beta_dev) {
beta_.resize(param_num);
gamma_.resize(param_num);
std::copy(gamma, gamma + param_num, gamma_.data());
std::copy(beta, beta + param_num, beta_.data());
int type_size = with_fp16 ? sizeof(half) : sizeof(float);
if (gamma_dev_ == nullptr) {
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16)
convertAndCopy(gamma_, reinterpret_cast<half *>(p));
else
convertAndCopy(gamma_, reinterpret_cast<float *>(p));
}
if (beta_dev_ == nullptr) {
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16)
convertAndCopy(beta_, reinterpret_cast<half *>(p));
else
convertAndCopy(beta_, reinterpret_cast<float *>(p));
}
}
LayernormShiftPartitionPluginDynamic::LayernormShiftPartitionPluginDynamic(
void const *serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &beta_);
DeserializeValue(&serialData, &serialLength, &gamma_);
DeserializeValue(&serialData, &serialLength, &param_num_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
DeserializeValue(&serialData, &serialLength, &shift_size_);
DeserializeValue(&serialData, &serialLength, &window_size_);
DeserializeValue(&serialData, &serialLength, &input_resolution_);
DeserializeValue(&serialData, &serialLength, &eps_);
int type_size = with_fp16_ ? sizeof(half) : sizeof(float);
{
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16_)
convertAndCopy(gamma_, reinterpret_cast<half *>(p));
else
convertAndCopy(gamma_, reinterpret_cast<float *>(p));
}
{
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16_)
convertAndCopy(beta_, reinterpret_cast<half *>(p));
else
convertAndCopy(beta_, reinterpret_cast<float *>(p));
}
}
bool LayernormShiftPartitionPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument("The input of LayernormShiftPartition "
"plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return in.type == nvinfer1::DataType::kHALF &&
in.format == nvinfer1::TensorFormat::kLINEAR;
} else {
return in.type == nvinfer1::DataType::kFLOAT &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType LayernormShiftPartitionPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(
index,
0,
platform::errors::InvalidArgument(
"The LayernormShiftPartition only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
nvinfer1::DimsExprs LayernormShiftPartitionPluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(
output_index,
0,
platform::errors::InvalidArgument(
"There is only one output of the LayernormShiftPartition, "
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs,
1,
platform::errors::InvalidArgument(
"The Input of the LayernormShiftPartition should be 1, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
ret.d[0] = expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[0].d[0],
*inputs[0].d[1]),
*expr_builder.constant(window_size_ * window_size_));
ret.d[1] = expr_builder.constant(window_size_ * window_size_);
ret.d[2] = inputs[0].d[2];
return ret;
}
int LayernormShiftPartitionPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const auto &input_dims = input_desc[0].dims;
auto input_type = input_desc[0].type;
int batch = input_dims.d[0];
int emb_dim = input_dims.d[2];
PADDLE_ENFORCE_EQ(
input_resolution_ * input_resolution_,
input_dims.d[1],
platform::errors::InvalidArgument(
"The LayernormShiftPartition‘s input_resolution is wrong (%d)",
input_dims.d[1]));
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(3) << "TRT Plugin DataType selected. LayernormShiftPartition-->fp32";
invokeLayernormShiftPartition(
reinterpret_cast<float *>(outputs[0]),
reinterpret_cast<const float *>(inputs[0]),
reinterpret_cast<const float *>(gamma_dev_.get()),
reinterpret_cast<const float *>(beta_dev_.get()),
batch,
input_resolution_,
input_resolution_,
emb_dim,
shift_size_,
window_size_,
eps_,
stream);
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(3) << "TRT Plugin DataType selected. LayernormShiftPartition-->half";
invokeLayernormShiftPartition(
reinterpret_cast<half *>(outputs[0]),
reinterpret_cast<const half *>(inputs[0]),
reinterpret_cast<const half *>(gamma_dev_.get()),
reinterpret_cast<const half *>(beta_dev_.get()),
batch,
input_resolution_,
input_resolution_,
emb_dim,
shift_size_,
window_size_,
eps_,
stream);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The LayerNorm TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
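The index arithmetic that produces output_bid in the kernels above can be summarized with a small Python sketch (an illustration under assumed sizes, not part of the patch): each (H, W) block position is cyclically shifted by shift_size and then regrouped into window_size x window_size windows.

def output_block_id(h_idx, w_idx, H, W, window_size, shift_size):
    # Mirrors the output_bid computation in the kernels (batch offset omitted).
    shifted_h = (h_idx - shift_size + H) % H if shift_size else h_idx
    shifted_w = (w_idx - shift_size + W) % W if shift_size else w_idx
    window_h, window_w = shifted_h // window_size, shifted_w // window_size
    window_idx = window_h * (W // window_size) + window_w
    idx_in_window = (shifted_h % window_size) * window_size + (shifted_w % window_size)
    return window_idx * window_size * window_size + idx_in_window

# Assumed example: a 14x14 feature map with 7x7 windows and no shift.
H = W = 14
ids = [output_block_id(h, w, H, W, 7, 0) for h in range(H) for w in range(W)]
assert sorted(ids) == list(range(H * W))  # the mapping is a permutation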
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class LayernormShiftPartitionPluginDynamic : public DynamicPluginTensorRT {
public:
LayernormShiftPartitionPluginDynamic(
const float* gamma,
const float* beta,
const int param_num,
int shift_size,
int window_size,
int input_resolution,
float eps,
bool with_fp16,
std::shared_ptr<void> gamma_dev = nullptr,
std::shared_ptr<void> beta_dev = nullptr);
LayernormShiftPartitionPluginDynamic(void const* serialData,
size_t serialLength);
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new LayernormShiftPartitionPluginDynamic(gamma_.data(),
beta_.data(),
beta_.size(),
shift_size_,
window_size_,
input_resolution_,
eps_,
with_fp16_,
gamma_dev_,
beta_dev_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "layernorm_shift_partition_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(beta_) + SerializedSize(gamma_) +
SerializedSize(param_num_) + SerializedSize(with_fp16_) +
SerializedSize(shift_size_) + SerializedSize(window_size_) +
SerializedSize(input_resolution_) + SerializedSize(eps_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, beta_);
SerializeValue(&buffer, gamma_);
SerializeValue(&buffer, param_num_);
SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, shift_size_);
SerializeValue(&buffer, window_size_);
SerializeValue(&buffer, input_resolution_);
SerializeValue(&buffer, eps_);
}
nvinfer1::DimsExprs getOutputDimensions(int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
bool with_fp16_;
std::vector<float> gamma_;
std::vector<float> beta_;
int window_size_;
int shift_size_;
int input_resolution_;
int param_num_;
float eps_;
std::shared_ptr<void> gamma_dev_;
std::shared_ptr<void> beta_dev_;
};
class LayernormShiftPartitionPluginDynamicCreator
: public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "layernorm_shift_partition_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new LayernormShiftPartitionPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(LayernormShiftPartitionPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestLayernormShiftPartitionPass(PassAutoScanTest):
"""
|
layer_norm
|
reshape2
|
reshape2
|
transpose2
|
reshape2
|
reshape2
|
"""
def sample_predictor_configs(self, program_config):
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({
"input_data": [1, 9, 96],
}, {
"input_data": [4, 3136, 768],
}, {
"input_data": [1, 784, 384],
})
yield config, ['layernorm_shift_partition'], (1e-5, 1e-5)
def sample_program_config(self, draw):
axis = [0, 1, 3, 2, 4, 5]
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
# begin_norm_axis has to be 2
begin_norm_axis = 2
batch_size = draw(st.integers(min_value=1, max_value=4))
window_size = draw(st.sampled_from([3, 5, 7]))
move_shape = draw(st.integers(min_value=1, max_value=8))
dim = draw(st.sampled_from([96, 192, 384, 768]))
def generate_input(attrs):
return np.random.random(
[attrs[1]["batch_size"],
*attrs[1]["input_dim"]]).astype(np.float32)
def generate_weight(attrs):
return np.random.random(attrs[1]['input_dim'][-1]).astype(
np.float32)
attrs = [{
'begin_norm_axis': begin_norm_axis,
'epsilon': epsilon,
}, {
'batch_size': batch_size,
'input_dim': [(window_size * move_shape)**2, dim],
}, {
'axis': axis,
'input_resolution': window_size * move_shape,
'move_shape': move_shape,
'window_size': window_size,
}]
layer_norm_op = OpConfig(type="layer_norm",
inputs={
"X": ["input_data"],
"Bias": ["layer_norm_bias"],
"Scale": ["layer_norm_scale"]
},
outputs={
"Y": ["layer_norm_output1"],
"Mean": ["layer_norm_output2"],
"Variance": ["layer_norm_output3"]
},
attrs={
"begin_norm_axis":
attrs[0]["begin_norm_axis"],
"epsilon": attrs[0]["epsilon"],
})
reshape_op2 = OpConfig(type="reshape2",
inputs={
"X": ["layer_norm_output1"],
},
outputs={
"Out": ["reshape_output2"],
"XShape": ["reshape_output2_xshape"],
},
attrs={
'shape': [
-1, attrs[2]["input_resolution"],
attrs[2]["input_resolution"],
attrs[1]["input_dim"][-1]
]
})
reshape_op3 = OpConfig(type="reshape2",
inputs={
"X": ["reshape_output2"],
},
outputs={
"Out": ["reshape_output3"],
"XShape": ["reshape_output3_xshape"],
},
attrs={
'shape': [
-1, attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1]
]
})
transpose_op4 = OpConfig(type='transpose2',
inputs={
"X": ["reshape_output3"],
},
outputs={"Out": ["transpose_output4"]},
attrs={"axis": attrs[2]['axis']})
reshape_op5 = OpConfig(type="reshape2",
inputs={
"X": ["transpose_output4"],
},
outputs={
"Out": ["reshape_output5"],
"XShape": ["reshape_output5_xshape"],
},
attrs={
'shape': [
-1, attrs[2]["window_size"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1]
]
})
reshape_op6 = OpConfig(
type="reshape2",
inputs={
"X": ["reshape_output5"],
},
outputs={
"Out": ["reshape_output6"],
"XShape": ["reshape_output6_xshape"],
},
attrs={
'shape':
[-1, attrs[2]["window_size"]**2, attrs[1]["input_dim"][-1]]
})
program_config = ProgramConfig(
ops=[
layer_norm_op, reshape_op2, reshape_op3, transpose_op4,
reshape_op5, reshape_op6
],
weights={
"layer_norm_bias":
TensorConfig(data_gen=partial(generate_weight, attrs)),
"layer_norm_scale":
TensorConfig(data_gen=partial(generate_weight, attrs))
},
inputs={
"input_data":
TensorConfig(data_gen=partial(generate_input, attrs)),
},
outputs=["reshape_output6"])
return program_config
def test(self):
self.run_and_statis(quant=False,
max_examples=20,
passes=["layernorm_shift_partition_fuse_pass"],
max_duration=250,
min_success_num=20)
if __name__ == "__main__":
unittest.main()