Unverified commit d17d0cd1, authored by W wenbin, committed by GitHub

Preln_Layernorm_Shift_Partition (#47099)

* prelnlayernorm_shift

* add ut

* remove paddle_enforce

* remove useless

* add UT

* remove UT

* add UT

* set timeout
Parent c1c2be2d
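This commit adds the preln_layernorm_x_fuse_pass, which fuses an elementwise_add followed by layernorm_shift_partition into a single preln_layernorm_shift_partition op, together with the matching TensorRT converter and plugin. Below is a minimal sketch of how the fusion is exercised in inference, assuming a model that contains the pattern; the model files and input tensor names are placeholders, and the API calls mirror those used in the unit test at the end of this diff.

import paddle.inference as paddle_infer

# Placeholder model files; any network containing the
# elementwise_add + layernorm_shift_partition pattern (e.g. a Swin-style block) would do.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 25,
    max_batch_size=4,
    min_subgraph_size=0,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
# The fused op is only supported with dynamic shape, so min/max/opt shapes
# must be provided; the tensor names here are placeholders.
config.set_trt_dynamic_shape_info(
    {"input_data_x": [1, 9, 96], "input_data_y": [1, 9, 96]},
    {"input_data_x": [4, 3136, 768], "input_data_y": [4, 3136, 768]},
    {"input_data_x": [1, 784, 384], "input_data_y": [1, 784, 384]},
)
predictor = paddle_infer.create_predictor(config)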
......@@ -129,6 +129,7 @@ if(WITH_TENSORRT)
pass_library(remove_padding_recover_padding_pass inference)
pass_library(delete_remove_padding_recover_padding_pass inference)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
endif()
if(WITH_TENSORRT AND NOT WIN32)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PrelnLayerNormX : public PatternBase {
PrelnLayerNormX(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_layernorm_x") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise_bias);
PATTERN_DECL_NODE(elementwise0);
PATTERN_DECL_NODE(elementwise1);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(elementwise1_out);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
};
void PrelnLayerNormX::operator()(PDNode *x, PDNode *y) {
auto *elementwise1 =
pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
auto *elementwise1_out_var =
pattern->NewNode(elementwise1_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("layernorm_shift_partition", "X");
elementwise1->LinksFrom({x, y}).LinksTo({elementwise1_out_var});
// Create nodes for layer_norm op.
auto *layer_norm = pattern->NewNode(layer_norm_repr())
->assert_is_op("layernorm_shift_partition");
auto *layer_norm_bias_var =
pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layernorm_shift_partition", "Bias");
auto *layer_norm_scale_var =
pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layernorm_shift_partition", "Scale");
auto *layer_norm_out_var =
pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layernorm_shift_partition", "Y");
// Add links for layer_norm op.
layer_norm
->LinksFrom(
{elementwise1_out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out_var});
}
} // namespace patterns
int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_layernorm_x_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
PDNode *y = nullptr;
x = gpd.mutable_pattern()
->NewNode("preln_layernorm_x_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "X");
y = gpd.mutable_pattern()
->NewNode("preln_layernorm_x_fuse/y")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "Y");
patterns::PrelnLayerNormX fused_pattern(gpd.mutable_pattern(),
"preln_layernorm_x_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle preln layernorm x fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise1_out, elementwise1_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln_layernorm_x_fuse pass in op compat failed.";
return;
}
std::unordered_set<const Node *> del_node_set;
// Create a preln_layernorm_shift_partition op node.
OpDesc new_desc(*layer_norm->Op());
new_desc.SetType("preln_layernorm_shift_partition");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetOutput("Out_0", {elementwise1_out->Name()});
new_desc.SetOutput("Out_1", {layer_norm_out->Name()});
new_desc.RemoveOutput("Y");
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise1);
del_node_set.insert(layer_norm);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise1_out);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void PrelnLayerNormXFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("preln_layernorm_x_fuse", graph);
int found_subgraph_count = ApplyPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_layernorm_x_fuse_pass,
paddle::framework::ir::PrelnLayerNormXFusePass);
REGISTER_PASS_CAPABILITY(preln_layernorm_x_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"elementwise_add", 1));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//
//     |           |                          |            |
// other_op1   other_op2                  other_op1    other_op2
//     |           |            fuse           \          /
//     |----elementwise_add      ->    preln_layernorm_shift_partition
//     |           |                           |          |
// other_op4  layernorm_shift_partition    other_op4   other_op3
//                 |
//             other_op3
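//
// The fused preln_layernorm_shift_partition op exposes both intermediate
// results: Out_0 is the elementwise_add sum (the residual consumed by
// other_op4) and Out_1 is the layernorm_shift_partition result (consumed by
// other_op3), mirroring SetOutput("Out_0"/"Out_1") in the pass implementation.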
class Graph;
class PrelnLayerNormXFusePass : public FusePassBase {
public:
PrelnLayerNormXFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1, 2})
.End();
}
virtual ~PrelnLayerNormXFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -2271,6 +2271,7 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(preln_layernorm_shift_partition)
USE_TRT_CONVERTER(merge_layernorm)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
......
......@@ -113,6 +113,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"layernorm_shift_partition_fuse_pass", //
"merge_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", //
"preln_layernorm_x_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
......
......@@ -77,6 +77,7 @@ list(
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc
preln_layernorm_shift_partition_op.cc
merge_layernorm_op.cc
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PrelnLayerNormShiftPartitionOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid preln_layernorm_shift_partition op to tensorrt "
"preln_layernorm_shift_partition plugin";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Y = engine_->GetITensor(op_desc.Input("Y").front());
std::vector<nvinfer1::ITensor*> inputs{X, Y};
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
const float eps = op_desc.HasAttr("epsilon")
? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
: 1e-5f;
const int window_size =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
const int input_resolution =
PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
const int shift_size =
op_desc.HasAttr("shift_size")
? PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"))
: 0;
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
auto bias_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
auto scale_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
nvinfer1::ILayer* layernorm_layer = nullptr;
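// Note: the plugin is implemented only for dynamic shape; the op teller
// rejects preln_layernorm_shift_partition when dynamic shape is disabled,
// so the dynamic-shape branch below is the only path that creates the layer.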
if (engine_->with_dynamic_shape()) {
plugin::PrelnLnormShiftPartitionPluginDynamic* plugin =
new plugin::PrelnLnormShiftPartitionPluginDynamic(
static_cast<const float*>(scale_weight.get().values),
static_cast<const float*>(bias_weight.get().values),
bias_weight.get().count,
shift_size,
window_size,
input_resolution,
eps,
with_fp16);
layernorm_layer =
engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
}
std::vector<std::string> output_names;
output_names.emplace_back(op_desc.Output("Out_0").front());
output_names.emplace_back(op_desc.Output("Out_1").front());
RreplenishLayerAndOutput(layernorm_layer,
"preln_layernorm_shift_partition",
output_names,
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(preln_layernorm_shift_partition,
PrelnLayerNormShiftPartitionOpConverter);
......@@ -2100,6 +2100,15 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
if (op_type == "preln_layernorm_shift_partition") {
if (!with_dynamic_shape) {
VLOG(3) << "the layernorm_shift_partition does not support "
"static shape yet";
return false;
}
}
if (op_type == "merge_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "The merge_layernorm op does not support "
......@@ -2259,9 +2268,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"squeeze2",
"unsqueeze2",
"layernorm_shift_partition",
"preln_layernorm_shift_partition",
"lookup_table",
"lookup_table_v2",
"expand_v2"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
......@@ -2376,6 +2387,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"unsqueeze2",
"fused_token_prune",
"layernorm_shift_partition",
"preln_layernorm_shift_partition",
"merge_layernorm",
"lookup_table",
"lookup_table_v2",
......
......@@ -33,9 +33,11 @@ list(
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu
layernorm_shift_partition_op.cu
prelnlayernorm_shift_partition_op.cu
merge_layernorm_op_plugin.cu
generic_plugin.cu
lookup_table.cu)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
many_emb_Layernorm_varseqlen_kernelMTron.cu
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0) shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f (instead of a bit shift) so the partial warp is
// handled correctly when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <typename T>
__global__ void preln_layernorm_shift_partition(T *out0,
T *out1,
const T *input0,
const T *input1,
const T *gamma,
const T *beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
int tid = threadIdx.x;
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
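// Each block handles one (H, W) token of one batch sample. The shifted
// indices below implement the cyclic window shift, and output_bid remaps the
// token into window-major order so out1 comes out already partitioned into
// window_size x window_size windows.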
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
const int index = bid * n + tid;
float local_out = 0;
if (tid < n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
local_out = static_cast<float>(__ldg(input0 + index));
#else
local_out = static_cast<float>(input0[index]);
#endif
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
local_out += static_cast<float>(__ldg(input1 + index));
#else
local_out += static_cast<float>(input1[index]);
#endif
out0[index] = local_out;
}
mean = blockReduceSum<float>(local_out);
if (threadIdx.x == 0) {
s_mean = mean / n;
}
__syncthreads();
float diff = (tid < n) ? (local_out - s_mean) : 0.0f;
variance = blockReduceSum<float>(diff * diff);
if (threadIdx.x == 0) {
s_variance = variance / n + eps;
}
__syncthreads();
if (tid < n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
out1[output_bid * n + tid] =
(T)(((local_out - s_mean) * rsqrtf(s_variance)) *
static_cast<float>(__ldg(&gamma[tid])) +
static_cast<float>(__ldg(&beta[tid])));
#else
out1[output_bid * n + tid] =
(T)(((local_out - s_mean) * rsqrtf(s_variance)) *
static_cast<float>(gamma[tid]) +
static_cast<float>(beta[tid]));
#endif
}
}
template <>
__global__ void preln_layernorm_shift_partition(half2 *out0_ptr,
half2 *out1_ptr,
const half2 *input0_ptr,
const half2 *input1_ptr,
const half2 *gamma_ptr,
const half2 *beta_ptr,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
int tid = threadIdx.x;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float2 local_out_fp2;
float local_out = 0.0f;
int id = bid * n + tid;
if (tid < n) {
half2 tmp = __hadd2(__ldg(input0_ptr + id), __ldg(input1_ptr + id));
local_out_fp2 = __half22float2(tmp);
local_out += local_out_fp2.x;
local_out += local_out_fp2.y;
out0_ptr[id] = tmp;
}
mean = blockReduceSum<float>(local_out);
if (threadIdx.x == 0) {
s_mean = mean / (n * 2);
}
__syncthreads();
if (tid < n) {
variance = (local_out_fp2.x - s_mean) * (local_out_fp2.x - s_mean);
variance += (local_out_fp2.y - s_mean) * (local_out_fp2.y - s_mean);
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (n * 2) + eps);
}
__syncthreads();
if (tid < n) {
float2 gamma_val = __half22float2(__ldg(&gamma_ptr[tid]));
float2 beta_val = __half22float2(__ldg(&beta_ptr[tid]));
local_out_fp2.x =
(local_out_fp2.x - s_mean) * s_variance * gamma_val.x + beta_val.x;
local_out_fp2.y =
(local_out_fp2.y - s_mean) * s_variance * gamma_val.y + beta_val.y;
out1_ptr[output_bid * n + tid] = __float22half2_rn(local_out_fp2);
}
#endif
}
#define kITE 4
template <typename T>
__global__ void preln_layernorm_shift_partition_v2(T *out0,
T *out1,
const T *__restrict input0,
const T *__restrict input1,
const T *__restrict gamma,
const T *__restrict beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
// constexpr int kITE = 4;
const int tid = threadIdx.x;
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
const int offset = bid * n;
const int output_offset = output_bid * n;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
float local_out[kITE];
float sum = 0.0f;
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
int index = offset + col_id;
if (col_id < n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
local_out[i] = static_cast<float>(__ldg(input0 + index));
#else
local_out[i] = static_cast<float>(input0[index]);
#endif
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
local_out[i] += static_cast<float>(__ldg(input1 + index));
#else
local_out[i] += static_cast<float>(input1[index]);
#endif
out0[index] = local_out[i];
sum += local_out[i];
}
}
mean = blockReduceSum<float>(sum);
if (tid == 0) {
s_mean = mean / n;
}
__syncthreads();
float var = 0.0f;
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
float diff = local_out[i] - s_mean;
local_out[i] = diff;
var += diff * diff;
}
}
variance = blockReduceSum<float>(var);
if (tid == 0) {
s_variance = rsqrtf(variance / n + eps);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
out1[output_offset + col_id] =
(T)(local_out[i] * s_variance *
static_cast<float>(__ldg(&gamma[col_id])) +
static_cast<float>(__ldg(&beta[col_id])));
#else
out1[output_offset + col_id] =
(T)(local_out[i] * s_variance * static_cast<float>(gamma[col_id]) +
static_cast<float>(beta[col_id]));
#endif
}
}
}
template <>
__global__ void preln_layernorm_shift_partition_v2(
half2 *out0_ptr,
half2 *out1_ptr,
const half2 *__restrict input0_ptr,
const half2 *__restrict input1_ptr,
const half2 *__restrict gamma_ptr,
const half2 *__restrict beta_ptr,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
// constexpr int ite = 4;
const int tid = threadIdx.x;
const int batch_offset = blockIdx.z * gridDim.y * gridDim.x;
const int bid = batch_offset + blockIdx.y * gridDim.x + blockIdx.x;
const int shifted_H_idx =
(shift_size != 0) ? ((blockIdx.y - shift_size + gridDim.y) % gridDim.y)
: blockIdx.y;
const int shifted_W_idx =
(shift_size != 0) ? ((blockIdx.x - shift_size + gridDim.x) % gridDim.x)
: blockIdx.x;
const int window_H_idx = shifted_H_idx / window_size;
const int window_W_idx = shifted_W_idx / window_size;
const int stride_of_window_H = W / window_size;
const int window_idx = window_H_idx * stride_of_window_H + window_W_idx;
const int idx_in_window = (shifted_H_idx % window_size) * window_size +
(shifted_W_idx % window_size);
const int output_bid =
batch_offset + window_idx * window_size * window_size + idx_in_window;
const int offset = bid * n;
const int output_offset = output_bid * n;
__shared__ float s_mean;
__shared__ float s_variance;
float mean = 0.0f;
float variance = 0.0f;
half2 local_out_half2[kITE];
// float sum = 0.0f;
half2 sum = __float2half2_rn(0.0f);
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
int index = offset + col_id;
local_out_half2[i] = __ldg(input0_ptr + index);
local_out_half2[i] =
__hadd2(local_out_half2[i], __ldg(input1_ptr + index));
out0_ptr[index] = local_out_half2[i];  // write the residual sum back by global index, not loop counter
sum += local_out_half2[i];
}
}
mean = blockReduceSum<float>(static_cast<float>(sum.x + sum.y));
if (threadIdx.x == 0) {
s_mean = mean / (n * 2);
}
__syncthreads();
float var = 0.0f;
half2 s_mean_2 = __float2half2_rn(s_mean);
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
local_out_half2[i] = local_out_half2[i] - s_mean_2;
float v1 = static_cast<float>(local_out_half2[i].x);
float v2 = static_cast<float>(local_out_half2[i].y);
var += v1 * v1 + v2 * v2;
}
}
variance = blockReduceSum<float>(var);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / (n * 2) + eps);
}
__syncthreads();
half2 s_var_2 = __float2half2_rn(s_variance);
#pragma unroll
for (int i = 0; i < kITE; i++) {
int col_id = i * blockDim.x + tid;
if (col_id < n) {
out1_ptr[output_offset + col_id] =
local_out_half2[i] * s_var_2 * __ldg(&gamma_ptr[col_id]) +
__ldg(&beta_ptr[col_id]);
}
}
#endif
}
template <typename T>
void invokePrelnLayernormShiftPartition(T *out0,
T *out1,
const T *input0,
const T *input1,
const T *gamma,
const T *beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps,
cudaStream_t stream) {
dim3 grid(W, H, batch);
int blockSize = (n + 31) / 32 * 32;
if (blockSize >= 768) {
blockSize = ((blockSize / 4) + 31) / 32 * 32;
preln_layernorm_shift_partition_v2<T>
<<<grid, blockSize, 0, stream>>>(out0,
out1,
input0,
input1,
gamma,
beta,
batch,
H,
W,
n,
shift_size,
window_size,
eps);
} else {
preln_layernorm_shift_partition<T>
<<<grid, blockSize, 0, stream>>>(out0,
out1,
input0,
input1,
gamma,
beta,
batch,
H,
W,
n,
shift_size,
window_size,
eps);
}
}
template <>
void invokePrelnLayernormShiftPartition(half *out0,
half *out1,
const half *input0,
const half *input1,
const half *gamma,
const half *beta,
int batch,
int H,
int W,
int n,
int shift_size,
int window_size,
const float eps,
cudaStream_t stream) {
dim3 grid(W, H, batch);
int blockSize = n / 2;
blockSize = (blockSize + 31) / 32 * 32;
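// The half specialization reinterprets the tensors as half2, so each thread
// processes two channels and the kernels are launched with n / 2 channels.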
if ((batch * H * W >= 512 && blockSize >= 768) || blockSize > 1024) {
blockSize = ((blockSize / 4) + 31) / 32 * 32;
preln_layernorm_shift_partition_v2<<<grid, blockSize, 0, stream>>>(
reinterpret_cast<half2 *>(out0),
reinterpret_cast<half2 *>(out1),
(const half2 *)input0,
(const half2 *)input1,
(const half2 *)gamma,
(const half2 *)beta,
batch,
H,
W,
n / 2,
shift_size,
window_size,
eps);
} else {
preln_layernorm_shift_partition<<<grid, blockSize, 0, stream>>>(
reinterpret_cast<half2 *>(out0),
reinterpret_cast<half2 *>(out1),
(const half2 *)input0,
(const half2 *)input1,
(const half2 *)gamma,
(const half2 *)beta,
batch,
H,
W,
n / 2,
shift_size,
window_size,
eps);
}
}
template <typename T>
static void convertAndCopy(const std::vector<float> &host, T *dev) {
T *host_ptr = new T[host.size()];
std::transform(host.begin(), host.end(), host_ptr, [](float x) {
return static_cast<T>(x);
});
cudaMemcpy(dev, host_ptr, sizeof(T) * host.size(), cudaMemcpyHostToDevice);
delete[] host_ptr;
}
void PrelnLnormShiftPartitionPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT {}
PrelnLnormShiftPartitionPluginDynamic::PrelnLnormShiftPartitionPluginDynamic(
const float *gamma,
const float *beta,
const int param_num,
int shift_size,
int window_size,
int input_resolution,
float eps,
bool with_fp16,
std::shared_ptr<void> gamma_dev,
std::shared_ptr<void> beta_dev)
: with_fp16_(with_fp16),
window_size_(window_size),
shift_size_(shift_size),
input_resolution_(input_resolution),
eps_(eps),
param_num_(param_num),
gamma_dev_(gamma_dev),
beta_dev_(beta_dev) {
beta_.resize(param_num);
gamma_.resize(param_num);
std::copy(gamma, gamma + param_num, gamma_.data());
std::copy(beta, beta + param_num, beta_.data());
int type_size = with_fp16 ? sizeof(half) : sizeof(float);
if (gamma_dev_ == nullptr) {
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16)
convertAndCopy(gamma_, reinterpret_cast<half *>(p));
else
convertAndCopy(gamma_, reinterpret_cast<float *>(p));
}
if (beta_dev_ == nullptr) {
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16)
convertAndCopy(beta_, reinterpret_cast<half *>(p));
else
convertAndCopy(beta_, reinterpret_cast<float *>(p));
}
}
PrelnLnormShiftPartitionPluginDynamic::PrelnLnormShiftPartitionPluginDynamic(
void const *serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &beta_);
DeserializeValue(&serialData, &serialLength, &gamma_);
DeserializeValue(&serialData, &serialLength, &param_num_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
DeserializeValue(&serialData, &serialLength, &shift_size_);
DeserializeValue(&serialData, &serialLength, &window_size_);
DeserializeValue(&serialData, &serialLength, &input_resolution_);
DeserializeValue(&serialData, &serialLength, &eps_);
int type_size = with_fp16_ ? sizeof(half) : sizeof(float);
{
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
gamma_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16_)
convertAndCopy(gamma_, reinterpret_cast<half *>(p));
else
convertAndCopy(gamma_, reinterpret_cast<float *>(p));
}
{
void *p;
cudaMalloc(reinterpret_cast<void **>(&p), param_num_ * type_size);
beta_dev_.reset(p, [](void *ptr) { cudaFree(ptr); });
if (with_fp16_)
convertAndCopy(beta_, reinterpret_cast<half *>(p));
else
convertAndCopy(beta_, reinterpret_cast<float *>(p));
}
}
bool PrelnLnormShiftPartitionPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
return in.type == nvinfer1::DataType::kHALF &&
in.format == nvinfer1::TensorFormat::kLINEAR;
} else {
return in.type == nvinfer1::DataType::kFLOAT &&
in.format == nvinfer1::TensorFormat::kLINEAR;
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType PrelnLnormShiftPartitionPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
return input_types[0];
}
nvinfer1::DimsExprs PrelnLnormShiftPartitionPluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
if (output_index == 0) return inputs[0];
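// Output 0 (the residual sum) keeps the input shape [B, H*W, C]; output 1 is
// window-partitioned to [B*H*W / window_size^2, window_size^2, C].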
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
ret.d[0] = expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*inputs[0].d[0],
*inputs[0].d[1]),
*expr_builder.constant(window_size_ * window_size_));
ret.d[1] = expr_builder.constant(window_size_ * window_size_);
ret.d[2] = inputs[0].d[2];
return ret;
}
int PrelnLnormShiftPartitionPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const auto &input_dims = input_desc[0].dims;
auto input_type = input_desc[0].type;
int batch = input_dims.d[0];
int emb_dim = input_dims.d[2];
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(3)
<< "TRT Plugin DataType selected. PreLayernormShiftPartition-->fp32";
invokePrelnLayernormShiftPartition(
reinterpret_cast<float *>(outputs[0]),
reinterpret_cast<float *>(outputs[1]),
reinterpret_cast<const float *>(inputs[0]),
reinterpret_cast<const float *>(inputs[1]),
reinterpret_cast<const float *>(gamma_dev_.get()),
reinterpret_cast<const float *>(beta_dev_.get()),
batch,
input_resolution_,
input_resolution_,
emb_dim,
shift_size_,
window_size_,
eps_,
stream);
} else if (input_type == nvinfer1::DataType::kHALF) {
VLOG(3)
<< "TRT Plugin DataType selected. PreLayernormShiftPartition-->half";
invokePrelnLayernormShiftPartition(
reinterpret_cast<half *>(outputs[0]),
reinterpret_cast<half *>(outputs[1]),
reinterpret_cast<const half *>(inputs[0]),
reinterpret_cast<const half *>(inputs[1]),
reinterpret_cast<const half *>(gamma_dev_.get()),
reinterpret_cast<const half *>(beta_dev_.get()),
batch,
input_resolution_,
input_resolution_,
emb_dim,
shift_size_,
window_size_,
eps_,
stream);
}
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class PrelnLnormShiftPartitionPluginDynamic : public DynamicPluginTensorRT {
public:
PrelnLnormShiftPartitionPluginDynamic(
const float* gamma,
const float* beta,
const int param_num,
int shift_size,
int window_size,
int input_resolution,
float eps,
bool with_fp16,
std::shared_ptr<void> gamma_dev = nullptr,
std::shared_ptr<void> beta_dev = nullptr);
PrelnLnormShiftPartitionPluginDynamic(void const* serialData,
size_t serialLength);
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new PrelnLnormShiftPartitionPluginDynamic(gamma_.data(),
beta_.data(),
beta_.size(),
shift_size_,
window_size_,
input_resolution_,
eps_,
with_fp16_,
gamma_dev_,
beta_dev_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "prelnlnorm_shift_partition_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(beta_) + SerializedSize(gamma_) +
SerializedSize(param_num_) + SerializedSize(with_fp16_) +
SerializedSize(shift_size_) + SerializedSize(window_size_) +
SerializedSize(input_resolution_) + SerializedSize(eps_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, beta_);
SerializeValue(&buffer, gamma_);
SerializeValue(&buffer, param_num_);
SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, shift_size_);
SerializeValue(&buffer, window_size_);
SerializeValue(&buffer, input_resolution_);
SerializeValue(&buffer, eps_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
bool with_fp16_;
std::vector<float> gamma_;
std::vector<float> beta_;
int window_size_;
int shift_size_;
int input_resolution_;
int param_num_;
float eps_;
std::shared_ptr<void> gamma_dev_;
std::shared_ptr<void> beta_dev_;
};
class PrelnLnormShiftPartitionPluginDynamicCreator
: public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "prelnlnorm_shift_partition_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new PrelnLnormShiftPartitionPluginDynamic(serial_data,
serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PrelnLnormShiftPartitionPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -173,6 +173,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_preln_layernorm_x_fuse_pass PROPERTIES TIMEOUT
240)
set_tests_properties(test_trt_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT
240)
set_tests_properties(test_trt_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
import unittest
import hypothesis.strategies as st
class TestLayernormShiftPartitionPass(PassAutoScanTest):
#
#     |           |                          |            |
# other_op1   other_op2                  other_op1    other_op2
#     |           |            fuse           \          /
#     |----elementwise_add      ->    preln_layernorm_shift_partition
#     |           |                           |          |
# other_op4  layernorm_shift_partition    other_op4   other_op3
#                 |
#             other_op3
def sample_predictor_configs(self, program_config):
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input_data_x": [1, 9, 96],
"input_data_y": [1, 9, 96],
},
{
"input_data_x": [4, 3136, 768],
"input_data_y": [4, 3136, 768],
},
{
"input_data_x": [1, 784, 384],
"input_data_y": [1, 784, 384],
},
)
yield config, ['preln_layernorm_shift_partition'], (1e-5, 1e-5)
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input_data_x": [1, 9, 96],
"input_data_y": [1, 9, 96],
},
{
"input_data_x": [4, 3136, 768],
"input_data_y": [4, 3136, 768],
},
{
"input_data_x": [1, 784, 384],
"input_data_y": [1, 784, 384],
},
)
yield config, ['preln_layernorm_shift_partition'], (1e-2, 1e-2)
def sample_program_config(self, draw):
axis = [0, 1, 3, 2, 4, 5]
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
# begin_norm_axis has to be 2
begin_norm_axis = 2
batch_size = draw(st.integers(min_value=1, max_value=4))
window_size = draw(st.sampled_from([3, 5, 7]))
move_shape = draw(st.integers(min_value=1, max_value=8))
dim = draw(st.sampled_from([96, 192, 384, 768]))
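# The token count must equal input_resolution ** 2 with
# input_resolution = window_size * move_shape, so the reshape/transpose chain
# below can split it into window_size x window_size windows.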
def generate_input(attrs):
return np.random.random(
[attrs[1]["batch_size"], *attrs[1]["input_dim"]]
).astype(np.float32)
def generate_weight(attrs):
return np.random.random(attrs[1]['input_dim'][-1]).astype(
np.float32
)
attrs = [
{
'begin_norm_axis': begin_norm_axis,
'epsilon': epsilon,
},
{
'batch_size': batch_size,
'input_dim': [(window_size * move_shape) ** 2, dim],
},
{
'axis': axis,
'input_resolution': window_size * move_shape,
'move_shape': move_shape,
'window_size': window_size,
},
]
elementwise_add_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
outputs={"Out": ["ele_out"]},
attrs={"axis": -1},
)
layer_norm_op = OpConfig(
type="layer_norm",
inputs={
"X": ["ele_out"],
"Bias": ["layer_norm_bias"],
"Scale": ["layer_norm_scale"],
},
outputs={
"Y": ["layer_norm_output1"],
"Mean": ["layer_norm_output2"],
"Variance": ["layer_norm_output3"],
},
attrs={
"begin_norm_axis": attrs[0]["begin_norm_axis"],
"epsilon": attrs[0]["epsilon"],
},
)
reshape_op2 = OpConfig(
type="reshape2",
inputs={
"X": ["layer_norm_output1"],
},
outputs={
"Out": ["reshape_output2"],
"XShape": ["reshape_output2_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["input_resolution"],
attrs[2]["input_resolution"],
attrs[1]["input_dim"][-1],
]
},
)
reshape_op3 = OpConfig(
type="reshape2",
inputs={
"X": ["reshape_output2"],
},
outputs={
"Out": ["reshape_output3"],
"XShape": ["reshape_output3_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1],
]
},
)
transpose_op4 = OpConfig(
type='transpose2',
inputs={
"X": ["reshape_output3"],
},
outputs={"Out": ["transpose_output4"]},
attrs={"axis": attrs[2]['axis']},
)
reshape_op5 = OpConfig(
type="reshape2",
inputs={
"X": ["transpose_output4"],
},
outputs={
"Out": ["reshape_output5"],
"XShape": ["reshape_output5_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["window_size"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1],
]
},
)
reshape_op6 = OpConfig(
type="reshape2",
inputs={
"X": ["reshape_output5"],
},
outputs={
"Out": ["reshape_output6"],
"XShape": ["reshape_output6_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["window_size"] ** 2,
attrs[1]["input_dim"][-1],
]
},
)
program_config = ProgramConfig(
ops=[
elementwise_add_op,
layer_norm_op,
reshape_op2,
reshape_op3,
transpose_op4,
reshape_op5,
reshape_op6,
],
weights={
"layer_norm_bias": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
"layer_norm_scale": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
},
inputs={
"input_data_x": TensorConfig(
data_gen=partial(generate_input, attrs)
),
"input_data_y": TensorConfig(
data_gen=partial(generate_input, attrs)
),
},
outputs=["ele_out", "reshape_output6"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["preln_layernorm_x_fuse_pass"],
max_duration=250,
min_success_num=50,
)
if __name__ == "__main__":
unittest.main()