Unverified commit 9e08633c, authored by Wang Bojun, committed by GitHub

Layernorm shift partition enhance (#46816)

* first version of ln_s_p with s>0

* refine and UT

* pass opt draft

* pass opt

* code refine

* code-style

* bug fix

* fix ci test

* code style
Parent f0af2708
@@ -3553,8 +3553,20 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
      });
  auto reshape1_out = pattern->NewNode(reshape1_out_repr())
                          ->AsIntermediate()
-                         ->assert_is_op_input("reshape2", "X")
                          ->assert_is_op_output("reshape2", "Out");
+  PDNode *roll1_op = nullptr;
+  PDNode *roll1_out = nullptr;
+  if (!with_roll_) {
+    reshape1_out->assert_is_op_input("reshape2", "X");
+  } else {
+    reshape1_out->assert_is_op_input("roll", "X");
+    roll1_op = pattern->NewNode(roll1_op_repr())->assert_is_op("roll");
+    roll1_out = pattern->NewNode(roll1_out_repr())
+                    ->AsIntermediate()
+                    ->assert_is_op_output("roll", "Out")
+                    ->assert_is_op_input("reshape2", "X");
+  }
  auto reshape2_op =
      pattern->NewNode(reshape2_op_repr())
          ->assert_is_op("reshape2")
@@ -3564,6 +3576,7 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
                                node->Op()->GetAttr("shape"))
                   .size() == 6);
      });
  auto reshape2_out = pattern->NewNode(reshape2_out_repr())
                          ->AsIntermediate()
                          ->assert_is_op_input("transpose2", "X")
@@ -3612,7 +3625,12 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
  layer_norm_op->LinksFrom({layer_norm_in, layer_norm_bias, layer_norm_scale})
      .LinksTo({layer_norm_out});
  reshape1_op->LinksFrom({layer_norm_out}).LinksTo({reshape1_out});
-  reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
+  if (!with_roll_) {
+    reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
+  } else {
+    roll1_op->LinksFrom({reshape1_out}).LinksTo({roll1_out});
+    reshape2_op->LinksFrom({roll1_out}).LinksTo({reshape2_out});
+  }
  transpose_op->LinksFrom({reshape2_out}).LinksTo({transpose_out});
  reshape3_op->LinksFrom({transpose_out}).LinksTo({reshape3_out});
  reshape4_op->LinksFrom({reshape3_out}).LinksTo({reshape4_out});
...
@@ -1914,15 +1914,17 @@ struct LayerNorm : public PatternBase {
//
// \brief Pattern looking for subgraph representing layernorm_shift_partition
-// operation.
+// operation with shift_size = 0.
//
struct LayernormShiftPartitionPattern : public PatternBase {
  LayernormShiftPartitionPattern(PDPattern* pattern,
-                                const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "layernorm_shift_partition") {}
+                                const std::string& name_scope,
+                                bool with_roll)
+      : PatternBase(pattern, name_scope, "layernorm_shift_partition"),
+        with_roll_(with_roll) {}
  PDNode* operator()();
+  bool with_roll_;
  PATTERN_DECL_NODE(layer_norm_in);
  PATTERN_DECL_NODE(layer_norm_op);
  PATTERN_DECL_NODE(layer_norm_bias);
@@ -1930,6 +1932,10 @@ struct LayernormShiftPartitionPattern : public PatternBase {
  PATTERN_DECL_NODE(layer_norm_out);
  PATTERN_DECL_NODE(reshape1_op);
  PATTERN_DECL_NODE(reshape1_out);
+  // optional op roll
+  PATTERN_DECL_NODE(roll1_op);
+  PATTERN_DECL_NODE(roll1_out);
  PATTERN_DECL_NODE(reshape2_op);
  PATTERN_DECL_NODE(reshape2_out);
  PATTERN_DECL_NODE(transpose_op);
...
@@ -16,6 +16,7 @@
#include <cmath>
#include <string>
+#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
@@ -85,22 +86,33 @@ LayerNormShiftPartitionFusePass::LayerNormShiftPartitionFusePass() {
      .AddAttr("axis")
      .IsType<std::vector<int>>()
      .End();
+  AddOpCompat(OpCompat("roll"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsType<std::vector<int64_t>>()
+      .End()
+      .AddAttr("shifts")
+      .IsType<std::vector<int64_t>>()
+      .End();
}
-void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
+int LayerNormShiftPartitionFusePass::ApplyPattern(ir::Graph* graph,
+                                                  bool with_roll) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::InvalidArgument(
          "The input graph of LayerNormShiftPartitionFusePass should not be "
          "nullptr."));
  FusePassBase::Init(scope_name_, graph);
  GraphPatternDetector gpd;
  patterns::LayernormShiftPartitionPattern shift_patition_pattern(
-      gpd.mutable_pattern(), scope_name_);
+      gpd.mutable_pattern(), scope_name_, with_roll);
  shift_patition_pattern();
  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
@@ -108,8 +120,13 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
      LOG(WARNING) << "layernorm_shift_partition_fuse in op compat failed.";
      return;
    }
-    VLOG(4) << "layernorm_shift_partition_fuse pass";
+    if (with_roll) {
+      VLOG(4)
+          << "layernorm_shift_partition_fuse pass, shift_size>0, with roll op";
+    } else {
+      VLOG(4) << "layernorm_shift_partition_fuse pass, shift_size=0, without "
+                 "roll op";
+    }
    GET_IR_NODE_FROM_SUBGRAPH(
        layer_norm_in, layer_norm_in, shift_patition_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
@@ -123,6 +140,15 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(reshape1_op, reshape1_op, shift_patition_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape1_out, reshape1_out, shift_patition_pattern);
+    Node* roll1_op = nullptr;
+    Node* roll1_out = nullptr;
+    if (with_roll) {
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_roll1_op, roll1_op, shift_patition_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(
+          tmp_roll1_out, roll1_out, shift_patition_pattern);
+      roll1_op = tmp_roll1_op;
+      roll1_out = tmp_roll1_out;
+    }
    GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, shift_patition_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape2_out, reshape2_out, shift_patition_pattern);
@@ -136,6 +162,21 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
    GET_IR_NODE_FROM_SUBGRAPH(reshape4_op, reshape4_op, shift_patition_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        reshape4_out, reshape4_out, shift_patition_pattern);
+    std::unordered_set<const Node*> del_node_set = {layer_norm_op,
+                                                    layer_norm_out,
+                                                    reshape1_op,
+                                                    reshape1_out,
+                                                    reshape2_op,
+                                                    reshape2_out,
+                                                    transpose_op,
+                                                    transpose_out,
+                                                    reshape3_op,
+                                                    reshape3_out,
+                                                    reshape4_op};
+    if (with_roll) {
+      del_node_set.insert(roll1_op);
+      del_node_set.insert(roll1_out);
+    }
    std::vector<int> shape_atr1 =
        PADDLE_GET_CONST(std::vector<int>, reshape1_op->Op()->GetAttr("shape"));
@@ -165,7 +206,20 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
    if (window_size < 0 || input_resolution < 0) {
      return;
    }
+    int shift_size = 0;
+    if (with_roll) {
+      std::vector<int64_t> roll_axis = PADDLE_GET_CONST(
+          std::vector<int64_t>, roll1_op->Op()->GetAttr("axis"));
+      std::vector<int64_t> roll_shifts = PADDLE_GET_CONST(
+          std::vector<int64_t>, roll1_op->Op()->GetAttr("shifts"));
+      if (roll_axis.size() != 2 || roll_axis[0] != 1 || roll_axis[1] != 2) {
+        return;
+      }
+      if (roll_shifts.size() != 2 || roll_shifts[0] != roll_shifts[1]) {
+        return;
+      }
+      shift_size = static_cast<int>(-roll_shifts[0]);
+    }
    OpDesc new_op_desc;
    new_op_desc.SetType("layernorm_shift_partition");
    new_op_desc.SetInput("X", {layer_norm_in->Name()});
@@ -176,6 +230,7 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
    new_op_desc.SetAttr("begin_norm_axis",
                        layer_norm_op->Op()->GetAttr("begin_norm_axis"));
    new_op_desc.SetAttr("window_size", window_size);
+    new_op_desc.SetAttr("shift_size", shift_size);
    new_op_desc.SetAttr("input_resolution", input_resolution);
    new_op_desc.Flush();
@@ -185,22 +240,19 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
    IR_NODE_LINK_TO(layer_norm_bias, layernorm_shift_partition);
    IR_NODE_LINK_TO(layer_norm_scale, layernorm_shift_partition);
    IR_NODE_LINK_TO(layernorm_shift_partition, reshape4_out);
-    GraphSafeRemoveNodes(graph,
-                         {layer_norm_op,
-                          layer_norm_out,
-                          reshape1_op,
-                          reshape1_out,
-                          reshape2_op,
-                          reshape2_out,
-                          transpose_op,
-                          transpose_out,
-                          reshape3_op,
-                          reshape3_out,
-                          reshape4_op});
+    GraphSafeRemoveNodes(graph, del_node_set);
    ++found_count;
  };
  gpd(graph, handler);
+  return found_count;
+}
+
+void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
+  int found_count = 0;
+  found_count += ApplyPattern(graph, true);
+  found_count += ApplyPattern(graph, false);
  AddStatis(found_count);
}
...
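A note on the shift_size derivation above: the pass only fuses a roll whose axis is [1, 2] and whose two shifts are equal, i.e. the Swin-Transformer shifted-window case where the spatial roll shift is -floor(window_size / 2), and the fused attribute is simply the negated shift. The following is a minimal standalone C++ sketch of that mapping, for illustration only; it is not part of this patch and DeriveShiftSize is a hypothetical helper name.

// Standalone sketch: how the matched roll op's attributes map to the fused
// layernorm_shift_partition op's shift_size attribute.
#include <cassert>
#include <cstdint>
#include <vector>

// Returns the fused shift_size, or -1 if the roll does not match the pattern
// the pass accepts (spatial axes [1, 2] with two equal shifts).
int DeriveShiftSize(const std::vector<int64_t>& axis,
                    const std::vector<int64_t>& shifts) {
  if (axis.size() != 2 || axis[0] != 1 || axis[1] != 2) return -1;
  if (shifts.size() != 2 || shifts[0] != shifts[1]) return -1;
  return static_cast<int>(-shifts[0]);
}

int main() {
  // Swin-style shifted window: shift = floor(window_size / 2), rolled by the
  // negative shift on both spatial axes, e.g. window_size = 7 -> shifts [-3, -3].
  const int64_t window_size = 7;
  const int64_t shift = window_size / 2;
  assert(DeriveShiftSize({1, 2}, {-shift, -shift}) == 3);
  // A roll over non-spatial axes is rejected (no fusion).
  assert(DeriveShiftSize({0, 1}, {-shift, -shift}) == -1);
  return 0;
}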
@@ -37,6 +37,26 @@ namespace ir {
//   reshape2
//      |
//   other_op
+//
+// or
+//
+//      |
+//  layer_norm
+//      |
+//   reshape2
+//      |
+//     roll
+//      |
+//   reshape2                          |
+//      |         fuse      layernorm_shift_patition
+//  transpose2     ->                  |
+//      |                           other_op
+//   reshape2
+//      |
+//   reshape2
+//      |
+//   other_op
class LayerNormShiftPartitionFusePass : public FusePassBase {
 public:
  LayerNormShiftPartitionFusePass();
@@ -44,6 +64,7 @@ class LayerNormShiftPartitionFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;
+  int ApplyPattern(ir::Graph *graph, bool with_roll) const;
 private:
  const std::string scope_name_{"layernorm_shift_partition_fuse"};
...
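For context, ApplyImpl now just runs the detector twice through ApplyPattern, once with and once without the optional roll node. A rough sketch of exercising the pass on an ir::Graph, in the style of Paddle's pass unit tests and assuming the usual PassRegistry lookup; RunFusePass is a hypothetical wrapper, not part of this patch.

// Sketch only: apply the fuse pass to a graph the way pass unit tests do.
// Assumes the pass is registered as "layernorm_shift_partition_fuse_pass".
#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace fw_ir = paddle::framework::ir;

std::unique_ptr<fw_ir::Graph> RunFusePass(std::unique_ptr<fw_ir::Graph> graph) {
  auto pass = fw_ir::PassRegistry::Instance().Get(
      "layernorm_shift_partition_fuse_pass");
  // Apply() rewrites the graph in place; matched subgraphs (with or without
  // the roll op) are replaced by a single layernorm_shift_partition op.
  graph.reset(pass->Apply(graph.release()));
  return graph;
}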
@@ -109,9 +109,9 @@ const std::vector<std::string> kTRTSubgraphPasses({
      "vit_attention_fuse_pass",                 //
      "trt_skip_layernorm_fuse_pass",            //
      "preln_skip_layernorm_fuse_pass",          //
-      "preln_residual_bias_fuse_pass",           //
      "layernorm_shift_partition_fuse_pass",     //
-      // "set_transformer_input_convert_pass",   //
+      "preln_residual_bias_fuse_pass",           //
+      // "set_transformer_input_convert_pass",   //
      "conv_bn_fuse_pass",                       //
      "unsqueeze2_eltwise_fuse_pass",            //
      "trt_squeeze2_matmul_fuse_pass",           //
...
@@ -40,11 +40,12 @@ class LayerNormShiftPartitionOpConverter : public OpConverter {
            : 1e-5f;
    const int window_size =
        PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
+    const int shift_size = PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"));
    const int input_resolution =
        PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
    // int shift_size = window_size / 2;
    // shift_size = (input_resolution <= window_size) ? 0 : shift_size;
-    int shift_size = 0;
+    // int shift_size = 0;
    PADDLE_ENFORCE_NOT_NULL(
        Bias_v,
...
@@ -14,6 +14,7 @@
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
+import math
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
@@ -219,5 +220,215 @@ class TestLayernormShiftPartitionPass(PassAutoScanTest):
                            min_success_num=50)
+
+
+class TestLayernormShiftPartition2Pass(PassAutoScanTest):
+    """
+        |
+    layer_norm
+        |
+     reshape2
+        |
+       roll
+        |
+     reshape2
+        |
+    transpose2
+        |
+     reshape2
+        |
+     reshape2
+        |
+    """
+
+    def sample_predictor_configs(self, program_config):
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=1,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Float32,
+            use_static=False,
+            use_calib_mode=False)
+        config.set_trt_dynamic_shape_info({
+            "input_data": [1, 9, 96],
+        }, {
+            "input_data": [4, 3136, 768],
+        }, {
+            "input_data": [1, 784, 384],
+        })
+        yield config, ['layernorm_shift_partition'], (1e-5, 1e-5)
+
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=4,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Half,
+            use_static=False,
+            use_calib_mode=False)
+        config.set_trt_dynamic_shape_info({
+            "input_data": [1, 9, 96],
+        }, {
+            "input_data": [4, 3136, 768],
+        }, {
+            "input_data": [1, 784, 384],
+        })
+        yield config, ['layernorm_shift_partition'], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        axis = [0, 1, 3, 2, 4, 5]
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        # begin_norm_axis has to be 2
+        begin_norm_axis = 2
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        window_size = draw(st.sampled_from([3, 5, 7]))
+        move_shape = draw(st.integers(min_value=1, max_value=8))
+        dim = draw(st.sampled_from([96, 192, 384, 768]))
+
+        def generate_input(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"],
+                 *attrs[1]["input_dim"]]).astype(np.float32)
+
+        def generate_weight(attrs):
+            return np.random.random(attrs[1]['input_dim'][-1]).astype(
+                np.float32)
+
+        attrs = [{
+            'begin_norm_axis': begin_norm_axis,
+            'epsilon': epsilon,
+        }, {
+            'batch_size': batch_size,
+            'input_dim': [(window_size * move_shape)**2, dim],
+        }, {
+            'axis': axis,
+            'input_resolution': window_size * move_shape,
+            'move_shape': move_shape,
+            'window_size': window_size,
+        }]
+
+        layer_norm_op = OpConfig(type="layer_norm",
+                                 inputs={
+                                     "X": ["input_data"],
+                                     "Bias": ["layer_norm_bias"],
+                                     "Scale": ["layer_norm_scale"]
+                                 },
+                                 outputs={
+                                     "Y": ["layer_norm_output1"],
+                                     "Mean": ["layer_norm_output2"],
+                                     "Variance": ["layer_norm_output3"]
+                                 },
+                                 attrs={
+                                     "begin_norm_axis":
+                                     attrs[0]["begin_norm_axis"],
+                                     "epsilon": attrs[0]["epsilon"],
+                                 })
+        reshape_op2 = OpConfig(type="reshape2",
+                               inputs={
+                                   "X": ["layer_norm_output1"],
+                               },
+                               outputs={
+                                   "Out": ["reshape_output2"],
+                                   "XShape": ["reshape_output2_xshape"],
+                               },
+                               attrs={
+                                   'shape': [
+                                       -1, attrs[2]["input_resolution"],
+                                       attrs[2]["input_resolution"],
+                                       attrs[1]["input_dim"][-1]
+                                   ]
+                               })
+        roll_op1 = OpConfig(type="roll",
+                            inputs={"X": ["reshape_output2"]},
+                            outputs={"Out": ["roll_output1"]},
+                            attrs={
+                                "axis": [1, 2],
+                                "shifts": [
+                                    -math.floor(
+                                        (attrs[2]["window_size"]) / 2.0),
+                                    -math.floor((attrs[2]["window_size"]) / 2.0)
+                                ]
+                            })
+        reshape_op3 = OpConfig(type="reshape2",
+                               inputs={
+                                   "X": ["roll_output1"],
+                               },
+                               outputs={
+                                   "Out": ["reshape_output3"],
+                                   "XShape": ["reshape_output3_xshape"],
+                               },
+                               attrs={
+                                   'shape': [
+                                       -1, attrs[2]["move_shape"],
+                                       attrs[2]["window_size"],
+                                       attrs[2]["move_shape"],
+                                       attrs[2]["window_size"],
+                                       attrs[1]["input_dim"][-1]
+                                   ]
+                               })
+        transpose_op4 = OpConfig(type='transpose2',
+                                 inputs={
+                                     "X": ["reshape_output3"],
+                                 },
+                                 outputs={"Out": ["transpose_output4"]},
+                                 attrs={"axis": attrs[2]['axis']})
+        reshape_op5 = OpConfig(type="reshape2",
+                               inputs={
+                                   "X": ["transpose_output4"],
+                               },
+                               outputs={
+                                   "Out": ["reshape_output5"],
+                                   "XShape": ["reshape_output5_xshape"],
+                               },
+                               attrs={
+                                   'shape': [
+                                       -1, attrs[2]["window_size"],
+                                       attrs[2]["window_size"],
+                                       attrs[1]["input_dim"][-1]
+                                   ]
+                               })
+        reshape_op6 = OpConfig(
+            type="reshape2",
+            inputs={
+                "X": ["reshape_output5"],
+            },
+            outputs={
+                "Out": ["reshape_output6"],
+                "XShape": ["reshape_output6_xshape"],
+            },
+            attrs={
+                'shape':
+                [-1, attrs[2]["window_size"]**2, attrs[1]["input_dim"][-1]]
+            })
+
+        program_config = ProgramConfig(
+            ops=[
+                layer_norm_op, reshape_op2, roll_op1, reshape_op3,
+                transpose_op4, reshape_op5, reshape_op6
+            ],
+            weights={
+                "layer_norm_bias":
+                TensorConfig(data_gen=partial(generate_weight, attrs)),
+                "layer_norm_scale":
+                TensorConfig(data_gen=partial(generate_weight, attrs))
+            },
+            inputs={
+                "input_data":
+                TensorConfig(data_gen=partial(generate_input, attrs)),
+            },
+            outputs=["reshape_output6"])
+
+        return program_config
+
+    def test(self):
+        self.run_and_statis(quant=False,
+                            max_examples=50,
+                            passes=["layernorm_shift_partition_fuse_pass"],
+                            max_duration=250,
+                            min_success_num=50)
+
+
if __name__ == "__main__":
    unittest.main()