未验证 提交 9e08633c 编写于 作者: W Wang Bojun 提交者: GitHub

Layernorm shift partition enhance (#46816)

* first version of ln_s_p with s>0

* refine and UT

* pass opt draft

* pass opt

* code refine

* code-style

* bug fix

* fix ci test

* code style
上级 f0af2708
......@@ -3553,8 +3553,20 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
});
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->AsIntermediate()
->assert_is_op_input("reshape2", "X")
->assert_is_op_output("reshape2", "Out");
PDNode *roll1_op = nullptr;
PDNode *roll1_out = nullptr;
if (!with_roll_) {
reshape1_out->assert_is_op_input("reshape2", "X");
} else {
reshape1_out->assert_is_op_input("roll", "X");
roll1_op = pattern->NewNode(roll1_op_repr())->assert_is_op("roll");
roll1_out = pattern->NewNode(roll1_out_repr())
->AsIntermediate()
->assert_is_op_output("roll", "Out")
->assert_is_op_input("reshape2", "X");
}
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())
->assert_is_op("reshape2")
......@@ -3564,6 +3576,7 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
node->Op()->GetAttr("shape"))
.size() == 6);
});
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->AsIntermediate()
->assert_is_op_input("transpose2", "X")
......@@ -3612,7 +3625,12 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
layer_norm_op->LinksFrom({layer_norm_in, layer_norm_bias, layer_norm_scale})
.LinksTo({layer_norm_out});
reshape1_op->LinksFrom({layer_norm_out}).LinksTo({reshape1_out});
reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
if (!with_roll_) {
reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
} else {
roll1_op->LinksFrom({reshape1_out}).LinksTo({roll1_out});
reshape2_op->LinksFrom({roll1_out}).LinksTo({reshape2_out});
}
transpose_op->LinksFrom({reshape2_out}).LinksTo({transpose_out});
reshape3_op->LinksFrom({transpose_out}).LinksTo({reshape3_out});
reshape4_op->LinksFrom({reshape3_out}).LinksTo({reshape4_out});
......
......@@ -1914,15 +1914,17 @@ struct LayerNorm : public PatternBase {
//
// \brief Pattern looking for subgraph representing layernorm_shift_partition
// operation.
// operation with shift_size = 0.
//
struct LayernormShiftPartitionPattern : public PatternBase {
LayernormShiftPartitionPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "layernorm_shift_partition") {}
const std::string& name_scope,
bool with_roll)
: PatternBase(pattern, name_scope, "layernorm_shift_partition"),
with_roll_(with_roll) {}
PDNode* operator()();
bool with_roll_;
PATTERN_DECL_NODE(layer_norm_in);
PATTERN_DECL_NODE(layer_norm_op);
PATTERN_DECL_NODE(layer_norm_bias);
......@@ -1930,6 +1932,10 @@ struct LayernormShiftPartitionPattern : public PatternBase {
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
// optional op roll
PATTERN_DECL_NODE(roll1_op);
PATTERN_DECL_NODE(roll1_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
PATTERN_DECL_NODE(transpose_op);
......
......@@ -16,6 +16,7 @@
#include <cmath>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......@@ -85,22 +86,33 @@ LayerNormShiftPartitionFusePass::LayerNormShiftPartitionFusePass() {
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("roll"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int64_t>>()
.End()
.AddAttr("shifts")
.IsType<std::vector<int64_t>>()
.End();
}
void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
int LayerNormShiftPartitionFusePass::ApplyPattern(ir::Graph* graph,
bool with_roll) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::InvalidArgument(
"The input graph of LayerNormShiftPartitionFusePass should not be "
"nullptr."));
FusePassBase::Init(scope_name_, graph);
GraphPatternDetector gpd;
patterns::LayernormShiftPartitionPattern shift_patition_pattern(
gpd.mutable_pattern(), scope_name_);
gpd.mutable_pattern(), scope_name_, with_roll);
shift_patition_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -108,8 +120,13 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
LOG(WARNING) << "layernorm_shift_partition_fuse in op compat failed.";
return;
}
VLOG(4) << "layernorm_shift_partition_fuse pass";
if (with_roll) {
VLOG(4)
<< "layernorm_shift_partition_fuse pass, shift_size>0, with roll op";
} else {
VLOG(4) << "layernorm_shift_partition_fuse pass, shift_size=0, without "
"roll op";
}
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_in, layer_norm_in, shift_patition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
......@@ -123,6 +140,15 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(reshape1_op, reshape1_op, shift_patition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape1_out, reshape1_out, shift_patition_pattern);
Node* roll1_op = nullptr;
Node* roll1_out = nullptr;
if (with_roll) {
GET_IR_NODE_FROM_SUBGRAPH(tmp_roll1_op, roll1_op, shift_patition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
tmp_roll1_out, roll1_out, shift_patition_pattern);
roll1_op = tmp_roll1_op;
roll1_out = tmp_roll1_out;
}
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, shift_patition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape2_out, reshape2_out, shift_patition_pattern);
......@@ -136,6 +162,21 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(reshape4_op, reshape4_op, shift_patition_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
reshape4_out, reshape4_out, shift_patition_pattern);
std::unordered_set<const Node*> del_node_set = {layer_norm_op,
layer_norm_out,
reshape1_op,
reshape1_out,
reshape2_op,
reshape2_out,
transpose_op,
transpose_out,
reshape3_op,
reshape3_out,
reshape4_op};
if (with_roll) {
del_node_set.insert(roll1_op);
del_node_set.insert(roll1_out);
}
std::vector<int> shape_atr1 =
PADDLE_GET_CONST(std::vector<int>, reshape1_op->Op()->GetAttr("shape"));
......@@ -165,7 +206,20 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
if (window_size < 0 || input_resolution < 0) {
return;
}
int shift_size = 0;
if (with_roll) {
std::vector<int64_t> roll_axis = PADDLE_GET_CONST(
std::vector<int64_t>, roll1_op->Op()->GetAttr("axis"));
std::vector<int64_t> roll_shifts = PADDLE_GET_CONST(
std::vector<int64_t>, roll1_op->Op()->GetAttr("shifts"));
if (roll_axis.size() != 2 || roll_axis[0] != 1 || roll_axis[1] != 2) {
return;
}
if (roll_shifts.size() != 2 || roll_shifts[0] != roll_shifts[1]) {
return;
}
shift_size = static_cast<int>(-roll_shifts[0]);
}
OpDesc new_op_desc;
new_op_desc.SetType("layernorm_shift_partition");
new_op_desc.SetInput("X", {layer_norm_in->Name()});
......@@ -176,6 +230,7 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
new_op_desc.SetAttr("begin_norm_axis",
layer_norm_op->Op()->GetAttr("begin_norm_axis"));
new_op_desc.SetAttr("window_size", window_size);
new_op_desc.SetAttr("shift_size", shift_size);
new_op_desc.SetAttr("input_resolution", input_resolution);
new_op_desc.Flush();
......@@ -185,22 +240,19 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
IR_NODE_LINK_TO(layer_norm_bias, layernorm_shift_partition);
IR_NODE_LINK_TO(layer_norm_scale, layernorm_shift_partition);
IR_NODE_LINK_TO(layernorm_shift_partition, reshape4_out);
GraphSafeRemoveNodes(graph,
{layer_norm_op,
layer_norm_out,
reshape1_op,
reshape1_out,
reshape2_op,
reshape2_out,
transpose_op,
transpose_out,
reshape3_op,
reshape3_out,
reshape4_op});
GraphSafeRemoveNodes(graph, del_node_set);
++found_count;
};
gpd(graph, handler);
return found_count;
}
void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
found_count += ApplyPattern(graph, true);
found_count += ApplyPattern(graph, false);
AddStatis(found_count);
}
......
......@@ -37,6 +37,26 @@ namespace ir {
// reshape2
// |
// other_op
//
// or
//
// |
// layer_norm
// |
// reshape2
// |
// roll
// |
// reshape2 |
// | fuse layernorm_shift_patition
// transpose2 -> |
// | other_op
// reshape2
// |
// reshape2
// |
// other_op
class LayerNormShiftPartitionFusePass : public FusePassBase {
public:
LayerNormShiftPartitionFusePass();
......@@ -44,6 +64,7 @@ class LayerNormShiftPartitionFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph *graph) const override;
int ApplyPattern(ir::Graph *graph, bool with_roll) const;
private:
const std::string scope_name_{"layernorm_shift_partition_fuse"};
......
......@@ -109,9 +109,9 @@ const std::vector<std::string> kTRTSubgraphPasses({
"vit_attention_fuse_pass", //
"trt_skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", //
"layernorm_shift_partition_fuse_pass", //
// "set_transformer_input_convert_pass", //
"preln_residual_bias_fuse_pass", //
// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", //
......
......@@ -40,11 +40,12 @@ class LayerNormShiftPartitionOpConverter : public OpConverter {
: 1e-5f;
const int window_size =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
const int shift_size = PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"));
const int input_resolution =
PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
// int shift_size = window_size / 2;
// shift_size = (input_resolution <= window_size) ? 0 : shift_size;
int shift_size = 0;
// int shift_size = 0;
PADDLE_ENFORCE_NOT_NULL(
Bias_v,
......
......@@ -14,6 +14,7 @@
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import math
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
......@@ -219,5 +220,215 @@ class TestLayernormShiftPartitionPass(PassAutoScanTest):
min_success_num=50)
class TestLayernormShiftPartition2Pass(PassAutoScanTest):
"""
|
layer_norm
|
reshape2
|
roll
|
reshape2
|
transpose2
|
reshape2
|
reshape2
|
"""
def sample_predictor_configs(self, program_config):
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=1,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({
"input_data": [1, 9, 96],
}, {
"input_data": [4, 3136, 768],
}, {
"input_data": [1, 784, 384],
})
yield config, ['layernorm_shift_partition'], (1e-5, 1e-5)
# trt dynamic_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=4,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False)
config.set_trt_dynamic_shape_info({
"input_data": [1, 9, 96],
}, {
"input_data": [4, 3136, 768],
}, {
"input_data": [1, 784, 384],
})
yield config, ['layernorm_shift_partition'], (1e-3, 1e-3)
def sample_program_config(self, draw):
axis = [0, 1, 3, 2, 4, 5]
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
# begin_norm_axis has to be 2
begin_norm_axis = 2
batch_size = draw(st.integers(min_value=1, max_value=4))
window_size = draw(st.sampled_from([3, 5, 7]))
move_shape = draw(st.integers(min_value=1, max_value=8))
dim = draw(st.sampled_from([96, 192, 384, 768]))
def generate_input(attrs):
return np.random.random(
[attrs[1]["batch_size"],
*attrs[1]["input_dim"]]).astype(np.float32)
def generate_weight(attrs):
return np.random.random(attrs[1]['input_dim'][-1]).astype(
np.float32)
attrs = [{
'begin_norm_axis': begin_norm_axis,
'epsilon': epsilon,
}, {
'batch_size': batch_size,
'input_dim': [(window_size * move_shape)**2, dim],
}, {
'axis': axis,
'input_resolution': window_size * move_shape,
'move_shape': move_shape,
'window_size': window_size,
}]
layer_norm_op = OpConfig(type="layer_norm",
inputs={
"X": ["input_data"],
"Bias": ["layer_norm_bias"],
"Scale": ["layer_norm_scale"]
},
outputs={
"Y": ["layer_norm_output1"],
"Mean": ["layer_norm_output2"],
"Variance": ["layer_norm_output3"]
},
attrs={
"begin_norm_axis":
attrs[0]["begin_norm_axis"],
"epsilon": attrs[0]["epsilon"],
})
reshape_op2 = OpConfig(type="reshape2",
inputs={
"X": ["layer_norm_output1"],
},
outputs={
"Out": ["reshape_output2"],
"XShape": ["reshape_output2_xshape"],
},
attrs={
'shape': [
-1, attrs[2]["input_resolution"],
attrs[2]["input_resolution"],
attrs[1]["input_dim"][-1]
]
})
roll_op1 = OpConfig(type="roll",
inputs={"X": ["reshape_output2"]},
outputs={"Out": ["roll_output1"]},
attrs={
"axis": [1, 2],
"shifts": [
-math.floor(
(attrs[2]["window_size"]) / 2.0),
-math.floor((attrs[2]["window_size"]) / 2.0)
]
})
reshape_op3 = OpConfig(type="reshape2",
inputs={
"X": ["roll_output1"],
},
outputs={
"Out": ["reshape_output3"],
"XShape": ["reshape_output3_xshape"],
},
attrs={
'shape': [
-1, attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[2]["move_shape"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1]
]
})
transpose_op4 = OpConfig(type='transpose2',
inputs={
"X": ["reshape_output3"],
},
outputs={"Out": ["transpose_output4"]},
attrs={"axis": attrs[2]['axis']})
reshape_op5 = OpConfig(type="reshape2",
inputs={
"X": ["transpose_output4"],
},
outputs={
"Out": ["reshape_output5"],
"XShape": ["reshape_output5_xshape"],
},
attrs={
'shape': [
-1, attrs[2]["window_size"],
attrs[2]["window_size"],
attrs[1]["input_dim"][-1]
]
})
reshape_op6 = OpConfig(
type="reshape2",
inputs={
"X": ["reshape_output5"],
},
outputs={
"Out": ["reshape_output6"],
"XShape": ["reshape_output6_xshape"],
},
attrs={
'shape':
[-1, attrs[2]["window_size"]**2, attrs[1]["input_dim"][-1]]
})
program_config = ProgramConfig(
ops=[
layer_norm_op, reshape_op2, roll_op1, reshape_op3,
transpose_op4, reshape_op5, reshape_op6
],
weights={
"layer_norm_bias":
TensorConfig(data_gen=partial(generate_weight, attrs)),
"layer_norm_scale":
TensorConfig(data_gen=partial(generate_weight, attrs))
},
inputs={
"input_data":
TensorConfig(data_gen=partial(generate_input, attrs)),
},
outputs=["reshape_output6"])
return program_config
def test(self):
self.run_and_statis(quant=False,
max_examples=50,
passes=["layernorm_shift_partition_fuse_pass"],
max_duration=250,
min_success_num=50)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册