未验证 提交 d17d0cd1 编写于 作者: W wenbin 提交者: GitHub

Preln_Layernorm_Shift_Partition (#47099)

* prelnlayernorm_shift

* add ut

* remove paddle_enforce

* remove useless

* add UT

* remove UT

* add UT

* set timeout
上级 c1c2be2d
......@@ -129,6 +129,7 @@ if(WITH_TENSORRT)
pass_library(remove_padding_recover_padding_pass inference)
pass_library(delete_remove_padding_recover_padding_pass inference)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
endif()
if(WITH_TENSORRT AND NOT WIN32)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PrelnLayerNormX : public PatternBase {
PrelnLayerNormX(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_layernorm_x") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise_bias);
PATTERN_DECL_NODE(elementwise0);
PATTERN_DECL_NODE(elementwise1);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(elementwise1_out);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
};
void PrelnLayerNormX::operator()(PDNode *x, PDNode *y) {
auto *elementwise1 =
pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
auto *elementwise1_out_var =
pattern->NewNode(elementwise1_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("layernorm_shift_partition", "X");
elementwise1->LinksFrom({x, y}).LinksTo({elementwise1_out_var});
// Create nodes for layer_norm op.
auto *layer_norm = pattern->NewNode(layer_norm_repr())
->assert_is_op("layernorm_shift_partition");
auto *layer_norm_bias_var =
pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layernorm_shift_partition", "Bias");
auto *layer_norm_scale_var =
pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layernorm_shift_partition", "Scale");
auto *layer_norm_out_var =
pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layernorm_shift_partition", "Y");
// Add links for layer_norm op.
layer_norm
->LinksFrom(
{elementwise1_out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out_var});
}
} // namespace patterns
int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_layernorm_x_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
PDNode *y = nullptr;
x = gpd.mutable_pattern()
->NewNode("preln_layernorm_x_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "X");
y = gpd.mutable_pattern()
->NewNode("preln_layernorm_x_fuse/y")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "Y");
patterns::PrelnLayerNormX fused_pattern(gpd.mutable_pattern(),
"preln_layernorm_x_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle preln layernorm x fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise1_out, elementwise1_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln_layernorm_x_fuse pass in op compat failed.";
return;
}
static int cnt = 0;
if (cnt++ > 0) {
// return;
}
std::unordered_set<const Node *> del_node_set;
// Create an PrelnLayerNormX op node
OpDesc new_desc(*layer_norm->Op());
new_desc.SetType("preln_layernorm_shift_partition");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetOutput("Out_0", {elementwise1_out->Name()});
new_desc.SetOutput("Out_1", {layer_norm_out->Name()});
new_desc.RemoveOutput("Y");
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise1);
del_node_set.insert(layer_norm);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise1_out);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void PrelnLayerNormXFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("preln_layernorm_x_fuse", graph);
int found_subgraph_count = ApplyPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_layernorm_x_fuse_pass,
paddle::framework::ir::PrelnLayerNormXFusePass);
REGISTER_PASS_CAPABILITY(preln_layernorm_x_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"elementwise_add", 1));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//
// | | | |
// other_op1 other_op2 other_op1 other_op2
// | | fuse \ /
// |------elementwise_add -> preln_layernorm_shift_partition
// | | | |
// other_op4 layernorm_shift_partition other_op4 other_op3
// |
// other_op3
class Graph;
class PrelnLayerNormXFusePass : public FusePassBase {
public:
PrelnLayerNormXFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1, 2})
.End();
}
virtual ~PrelnLayerNormXFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -2271,6 +2271,7 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(preln_layernorm_shift_partition)
USE_TRT_CONVERTER(merge_layernorm)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
......
......@@ -113,6 +113,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"layernorm_shift_partition_fuse_pass", //
"merge_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", //
"preln_layernorm_x_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
......
......@@ -77,6 +77,7 @@ list(
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc
preln_layernorm_shift_partition_op.cc
merge_layernorm_op.cc
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PrelnLayerNormShiftPartitionOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid preln_layernorm_shift_partition op to tensorrt "
"preln_layernorm_shift_partition plugin";
framework::OpDesc op_desc(op, nullptr);
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Y = engine_->GetITensor(op_desc.Input("Y").front());
std::vector<nvinfer1::ITensor*> inputs{X, Y};
auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());
const float eps = op_desc.HasAttr("epsilon")
? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
: 1e-5f;
const int window_size =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
const int input_resolution =
PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
const int shift_size =
op_desc.HasAttr("shift_size")
? PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"))
: 0;
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
auto bias_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
auto scale_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
nvinfer1::ILayer* layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::PrelnLnormShiftPartitionPluginDynamic* plugin =
new plugin::PrelnLnormShiftPartitionPluginDynamic(
static_cast<const float*>(scale_weight.get().values),
static_cast<const float*>(bias_weight.get().values),
bias_weight.get().count,
shift_size,
window_size,
input_resolution,
eps,
with_fp16);
layernorm_layer =
engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
}
std::vector<std::string> output_names;
output_names.emplace_back(op_desc.Output("Out_0").front());
output_names.emplace_back(op_desc.Output("Out_1").front());
RreplenishLayerAndOutput(layernorm_layer,
"preln_layernorm_shift_partition",
output_names,
test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(preln_layernorm_shift_partition,
PrelnLayerNormShiftPartitionOpConverter);
......@@ -2100,6 +2100,15 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
}
if (op_type == "preln_layernorm_shift_partition") {
if (!with_dynamic_shape) {
VLOG(3) << "the layernorm_shift_partition does not support "
"static shape yet";
return false;
}
}
if (op_type == "merge_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "The merge_layernorm op does not support "
......@@ -2259,9 +2268,11 @@ struct SimpleOpTypeSetTeller : public Teller {
"squeeze2",
"unsqueeze2",
"layernorm_shift_partition",
"preln_layernorm_shift_partition",
"lookup_table",
"lookup_table_v2",
"expand_v2"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
......@@ -2376,6 +2387,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"unsqueeze2",
"fused_token_prune",
"layernorm_shift_partition",
"preln_layernorm_shift_partition",
"merge_layernorm",
"lookup_table",
"lookup_table_v2",
......
......@@ -33,9 +33,11 @@ list(
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu
layernorm_shift_partition_op.cu
prelnlayernorm_shift_partition_op.cu
merge_layernorm_op_plugin.cu
generic_plugin.cu
lookup_table.cu)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
many_emb_Layernorm_varseqlen_kernelMTron.cu
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class PrelnLnormShiftPartitionPluginDynamic : public DynamicPluginTensorRT {
public:
PrelnLnormShiftPartitionPluginDynamic(
const float* gamma,
const float* beta,
const int param_num,
int shift_size,
int window_size,
int input_resolution,
float eps,
bool with_fp16,
std::shared_ptr<void> gamma_dev = nullptr,
std::shared_ptr<void> beta_dev = nullptr);
PrelnLnormShiftPartitionPluginDynamic(void const* serialData,
size_t serialLength);
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new PrelnLnormShiftPartitionPluginDynamic(gamma_.data(),
beta_.data(),
beta_.size(),
shift_size_,
window_size_,
input_resolution_,
eps_,
with_fp16_,
gamma_dev_,
beta_dev_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "prelnlnorm_shift_partition_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(beta_) + SerializedSize(gamma_) +
SerializedSize(param_num_) + SerializedSize(with_fp16_) +
SerializedSize(shift_size_) + SerializedSize(window_size_) +
SerializedSize(input_resolution_) + SerializedSize(eps_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, beta_);
SerializeValue(&buffer, gamma_);
SerializeValue(&buffer, param_num_);
SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, shift_size_);
SerializeValue(&buffer, window_size_);
SerializeValue(&buffer, input_resolution_);
SerializeValue(&buffer, eps_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
bool with_fp16_;
std::vector<float> gamma_;
std::vector<float> beta_;
int window_size_;
int shift_size_;
int input_resolution_;
int param_num_;
float eps_;
std::shared_ptr<void> gamma_dev_;
std::shared_ptr<void> beta_dev_;
};
class PrelnLnormShiftPartitionPluginDynamicCreator
: public TensorRTPluginCreator {
public:
const char* getPluginName() const TRT_NOEXCEPT override {
return "prelnlnorm_shift_partition_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new PrelnLnormShiftPartitionPluginDynamic(serial_data,
serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(PrelnLnormShiftPartitionPluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -173,6 +173,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_preln_layernorm_x_fuse_pass PROPERTIES TIMEOUT
240)
set_tests_properties(test_trt_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT
240)
set_tests_properties(test_trt_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
import unittest
import hypothesis.strategies as st
class TestLayernormShiftPartitionPass(PassAutoScanTest):
#
# | | | |
# other_op1 other_op2 other_op1 other_op2
# | | fuse \ /
# |------elementwise_add -> preln_layernorm_shift_partition
# | | | |
# other_op4 layernorm_shift_partition other_op4 other_op3
# |
# other_op3
def sample_predictor_configs(self, program_config):
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input_data_x": [1, 9, 96],
"input_data_y": [1, 9, 96],
},
{
"input_data_x": [4, 3136, 768],
"input_data_y": [4, 3136, 768],
},
{
"input_data_x": [1, 784, 384],
"input_data_y": [1, 784, 384],
},
)
yield config, ['preln_layernorm_shift_partition'], (1e-5, 1e-5)
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"input_data_x": [1, 9, 96],
"input_data_y": [1, 9, 96],
},
{
"input_data_x": [4, 3136, 768],
"input_data_y": [4, 3136, 768],
},
{
"input_data_x": [1, 784, 384],
"input_data_y": [1, 784, 384],
},
)
yield config, ['preln_layernorm_shift_partition'], (1e-2, 1e-2)
def sample_program_config(self, draw):
axis = [0, 1, 3, 2, 4, 5]
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
# begin_norm_axis has to be 2
begin_norm_axis = 2
batch_size = draw(st.integers(min_value=1, max_value=4))
window_size = draw(st.sampled_from([3, 5, 7]))
move_shape = draw(st.integers(min_value=1, max_value=8))
dim = draw(st.sampled_from([96, 192, 384, 768]))
def generate_input(attrs):
return np.random.random(
[attrs[1]["batch_size"], *attrs[1]["input_dim"]]
).astype(np.float32)
def generate_weight(attrs):
return np.random.random(attrs[1]['input_dim'][-1]).astype(
np.float32
)
attrs = [
{
'begin_norm_axis': begin_norm_axis,
'epsilon': epsilon,
},
{
'batch_size': batch_size,
'input_dim': [(window_size * move_shape) ** 2, dim],
},
{
'axis': axis,
'input_resolution': window_size * move_shape,
'move_shape': move_shape,
'window_size': window_size,
},
]
elementwise_add_op = OpConfig(
type="elementwise_add",
inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
outputs={"Out": ["ele_out"]},
attrs={"axis": -1},
)
layer_norm_op = OpConfig(
type="layer_norm",
inputs={
"X": ["ele_out"],
"Bias": ["layer_norm_bias"],
"Scale": ["layer_norm_scale"],
},
outputs={
"Y": ["layer_norm_output1"],
"Mean": ["layer_norm_output2"],
"Variance": ["layer_norm_output3"],
},
attrs={
"begin_norm_axis": attrs[0]["begin_norm_axis"],
"epsilon": attrs[0]["epsilon"],
},
)
reshape_op2 = OpConfig(
type="reshape2",
inputs={
"X": ["layer_norm_output1"],
},
outputs={
"Out": ["reshape_output2"],
"XShape": ["reshape_output2_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["input_resolution"],
attrs[2]["input_resolution"],
attrs[1]["input_dim"][-1],
]
},
)
reshape_op3 = OpConfig(
type="reshape2",
inputs={
"X": ["reshape_output2"],
},
outputs={
"Out": ["reshape_output3"],
"XShape": ["reshape_output3_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1],
]
},
)
transpose_op4 = OpConfig(
type='transpose2',
inputs={
"X": ["reshape_output3"],
},
outputs={"Out": ["transpose_output4"]},
attrs={"axis": attrs[2]['axis']},
)
reshape_op5 = OpConfig(
type="reshape2",
inputs={
"X": ["transpose_output4"],
},
outputs={
"Out": ["reshape_output5"],
"XShape": ["reshape_output5_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["window_size"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1],
]
},
)
reshape_op6 = OpConfig(
type="reshape2",
inputs={
"X": ["reshape_output5"],
},
outputs={
"Out": ["reshape_output6"],
"XShape": ["reshape_output6_xshape"],
},
attrs={
'shape': [
-1,
attrs[2]["window_size"] ** 2,
attrs[1]["input_dim"][-1],
]
},
)
program_config = ProgramConfig(
ops=[
elementwise_add_op,
layer_norm_op,
reshape_op2,
reshape_op3,
transpose_op4,
reshape_op5,
reshape_op6,
],
weights={
"layer_norm_bias": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
"layer_norm_scale": TensorConfig(
data_gen=partial(generate_weight, attrs)
),
},
inputs={
"input_data_x": TensorConfig(
data_gen=partial(generate_input, attrs)
),
"input_data_y": TensorConfig(
data_gen=partial(generate_input, attrs)
),
},
outputs=["ele_out", "reshape_output6"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["preln_layernorm_x_fuse_pass"],
max_duration=250,
min_success_num=50,
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册