From 9ed1454ad0d376080c8fe545f99718124156fa9c Mon Sep 17 00:00:00 2001
From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com>
Date: Thu, 20 Oct 2022 14:30:28 +0800
Subject: [PATCH] [Cherry-pick] layernorm shift partition enhance (#47086)

* Enhance the layernorm shift partition fuse op when shift size > 0 (roll shifting)

* fix cherry-pick test
---
 .../framework/ir/graph_pattern_detector.cc    |  22 +-
 .../framework/ir/graph_pattern_detector.h     |  12 +-
 .../ir/layernorm_shift_partition_fuse_pass.cc |  92 ++++++--
 .../ir/layernorm_shift_partition_fuse_pass.h  |  21 ++
 .../convert/layernorm_shift_partition_op.cc   |   4 +-
 .../test_layernorm_shift_partition_pass.py    | 211 ++++++++++++++++++
 6 files changed, 334 insertions(+), 28 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 0d63ce2121..fc68cb514c 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -3535,8 +3535,20 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
           });
   auto reshape1_out = pattern->NewNode(reshape1_out_repr())
                           ->AsIntermediate()
-                          ->assert_is_op_input("reshape2", "X")
                           ->assert_is_op_output("reshape2", "Out");
+  PDNode *roll1_op = nullptr;
+  PDNode *roll1_out = nullptr;
+
+  if (!with_roll_) {
+    reshape1_out->assert_is_op_input("reshape2", "X");
+  } else {
+    reshape1_out->assert_is_op_input("roll", "X");
+    roll1_op = pattern->NewNode(roll1_op_repr())->assert_is_op("roll");
+    roll1_out = pattern->NewNode(roll1_out_repr())
+                    ->AsIntermediate()
+                    ->assert_is_op_output("roll", "Out")
+                    ->assert_is_op_input("reshape2", "X");
+  }
   auto reshape2_op = pattern->NewNode(reshape2_op_repr())
                          ->assert_is_op("reshape2")
                          ->assert_more([](Node *node) {
                            return PADDLE_GET_CONST(std::vector<int>,
                                                    node->Op()->GetAttr("shape"))
                                       .size() == 6;
                          });
@@ -3546,6 +3558,7 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
                                                    node->Op()->GetAttr("shape"))
                                       .size() == 6;
                          });
+
   auto reshape2_out = pattern->NewNode(reshape2_out_repr())
                           ->AsIntermediate()
                           ->assert_is_op_input("transpose2", "X")
                           ->assert_is_op_output("reshape2", "Out");
@@ -3594,7 +3607,12 @@ PDNode *patterns::LayernormShiftPartitionPattern::operator()() {
   layer_norm_op->LinksFrom({layer_norm_in, layer_norm_bias, layer_norm_scale})
       .LinksTo({layer_norm_out});
   reshape1_op->LinksFrom({layer_norm_out}).LinksTo({reshape1_out});
-  reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
+  if (!with_roll_) {
+    reshape2_op->LinksFrom({reshape1_out}).LinksTo({reshape2_out});
+  } else {
+    roll1_op->LinksFrom({reshape1_out}).LinksTo({roll1_out});
+    reshape2_op->LinksFrom({roll1_out}).LinksTo({reshape2_out});
+  }
   transpose_op->LinksFrom({reshape2_out}).LinksTo({transpose_out});
   reshape3_op->LinksFrom({transpose_out}).LinksTo({reshape3_out});
   reshape4_op->LinksFrom({reshape3_out}).LinksTo({reshape4_out});
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index b2eb740b9a..27bb69b050 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1917,11 +1917,13 @@ struct LayerNorm : public PatternBase {
 //
 struct LayernormShiftPartitionPattern : public PatternBase {
   LayernormShiftPartitionPattern(PDPattern* pattern,
-                                 const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "layernorm_shift_partition") {}
+                                 const std::string& name_scope,
+                                 bool with_roll)
+      : PatternBase(pattern, name_scope, "layernorm_shift_partition"),
+        with_roll_(with_roll) {}
 
   PDNode* operator()();
-
+  bool with_roll_;
   PATTERN_DECL_NODE(layer_norm_in);
   PATTERN_DECL_NODE(layer_norm_op);
   PATTERN_DECL_NODE(layer_norm_bias);
@@ -1929,6 +1931,10 @@ struct LayernormShiftPartitionPattern : public PatternBase {
   PATTERN_DECL_NODE(layer_norm_out);
   PATTERN_DECL_NODE(reshape1_op);
   PATTERN_DECL_NODE(reshape1_out);
+  // optional op roll
+  PATTERN_DECL_NODE(roll1_op);
+  PATTERN_DECL_NODE(roll1_out);
+
   PATTERN_DECL_NODE(reshape2_op);
   PATTERN_DECL_NODE(reshape2_out);
   PATTERN_DECL_NODE(transpose_op);
diff --git a/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc
index 9353f4b3ef..dbe990f636 100644
--- a/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.cc
@@ -16,6 +16,7 @@
 #include <cmath>
 #include <string>
+#include <unordered_set>
 
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
@@ -85,22 +86,33 @@ LayerNormShiftPartitionFusePass::LayerNormShiftPartitionFusePass() {
       .AddAttr("axis")
      .IsType<std::vector<int>>()
       .End();
+  AddOpCompat(OpCompat("roll"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsType<std::vector<int64_t>>()
+      .End()
+      .AddAttr("shifts")
+      .IsType<std::vector<int64_t>>()
+      .End();
 }
 
-void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
+int LayerNormShiftPartitionFusePass::ApplyPattern(ir::Graph* graph,
+                                                  bool with_roll) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph,
       platform::errors::InvalidArgument(
           "The input graph of LayerNormShiftPartitionFusePass should not be "
          "nullptr."));
-
   FusePassBase::Init(scope_name_, graph);
-
   GraphPatternDetector gpd;
   patterns::LayernormShiftPartitionPattern shift_patition_pattern(
-      gpd.mutable_pattern(), scope_name_);
+      gpd.mutable_pattern(), scope_name_, with_roll);
   shift_patition_pattern();
-
   int found_count = 0;
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -108,8 +120,13 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
       LOG(WARNING) << "layernorm_shift_partition_fuse in op compat failed.";
       return;
     }
-
-    VLOG(4) << "layernorm_shift_partition_fuse pass";
+    if (with_roll) {
+      VLOG(4)
+          << "layernorm_shift_partition_fuse pass, shift_size>0, with roll op";
+    } else {
+      VLOG(4) << "layernorm_shift_partition_fuse pass, shift_size=0, without "
+                 "roll op";
+    }
     GET_IR_NODE_FROM_SUBGRAPH(
         layer_norm_in, layer_norm_in, shift_patition_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
@@ -123,6 +140,15 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(reshape1_op, reshape1_op, shift_patition_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         reshape1_out, reshape1_out, shift_patition_pattern);
+    Node* roll1_op = nullptr;
+    Node* roll1_out = nullptr;
+    if (with_roll) {
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_roll1_op, roll1_op, shift_patition_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(
+          tmp_roll1_out, roll1_out, shift_patition_pattern);
+      roll1_op = tmp_roll1_op;
+      roll1_out = tmp_roll1_out;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, shift_patition_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         reshape2_out, reshape2_out, shift_patition_pattern);
@@ -136,6 +162,21 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(reshape4_op, reshape4_op, shift_patition_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
         reshape4_out, reshape4_out, shift_patition_pattern);
+    std::unordered_set<const Node*> del_node_set = {layer_norm_op,
+                                                    layer_norm_out,
+                                                    reshape1_op,
+                                                    reshape1_out,
+                                                    reshape2_op,
+                                                    reshape2_out,
+                                                    transpose_op,
+                                                    transpose_out,
+                                                    reshape3_op,
+                                                    reshape3_out,
+                                                    reshape4_op};
+    if (with_roll) {
+      del_node_set.insert(roll1_op);
+      del_node_set.insert(roll1_out);
+    }
 
     std::vector<int> shape_atr1 =
         PADDLE_GET_CONST(std::vector<int>, reshape1_op->Op()->GetAttr("shape"));
@@ -165,7 +206,20 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
     if (window_size < 0 || input_resolution < 0) {
       return;
     }
-
+    int shift_size = 0;
+    if (with_roll) {
+      std::vector<int64_t> roll_axis = PADDLE_GET_CONST(
+          std::vector<int64_t>, roll1_op->Op()->GetAttr("axis"));
+      std::vector<int64_t> roll_shifts = PADDLE_GET_CONST(
+          std::vector<int64_t>, roll1_op->Op()->GetAttr("shifts"));
+      if (roll_axis.size() != 2 || roll_axis[0] != 1 || roll_axis[1] != 2) {
+        return;
+      }
+      if (roll_shifts.size() != 2 || roll_shifts[0] != roll_shifts[1]) {
+        return;
+      }
+      shift_size = static_cast<int>(-roll_shifts[0]);
+    }
     OpDesc new_op_desc;
     new_op_desc.SetType("layernorm_shift_partition");
     new_op_desc.SetInput("X", {layer_norm_in->Name()});
@@ -176,6 +230,7 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
     new_op_desc.SetAttr("begin_norm_axis",
                         layer_norm_op->Op()->GetAttr("begin_norm_axis"));
     new_op_desc.SetAttr("window_size", window_size);
+    new_op_desc.SetAttr("shift_size", shift_size);
     new_op_desc.SetAttr("input_resolution", input_resolution);
     new_op_desc.Flush();
 
@@ -185,22 +240,19 @@ void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
     IR_NODE_LINK_TO(layer_norm_bias, layernorm_shift_partition);
     IR_NODE_LINK_TO(layer_norm_scale, layernorm_shift_partition);
     IR_NODE_LINK_TO(layernorm_shift_partition, reshape4_out);
-    GraphSafeRemoveNodes(graph,
-                         {layer_norm_op,
-                          layer_norm_out,
-                          reshape1_op,
-                          reshape1_out,
-                          reshape2_op,
-                          reshape2_out,
-                          transpose_op,
-                          transpose_out,
-                          reshape3_op,
-                          reshape3_out,
-                          reshape4_op});
+    GraphSafeRemoveNodes(graph, del_node_set);
     ++found_count;
   };
 
   gpd(graph, handler);
+
+  return found_count;
+}
+
+void LayerNormShiftPartitionFusePass::ApplyImpl(ir::Graph* graph) const {
+  int found_count = 0;
+  found_count += ApplyPattern(graph, true);
+  found_count += ApplyPattern(graph, false);
   AddStatis(found_count);
 }
diff --git a/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h
index 7c3d435ef4..6bbcd64e30 100644
--- a/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h
+++ b/paddle/fluid/framework/ir/layernorm_shift_partition_fuse_pass.h
@@ -37,6 +37,26 @@ namespace ir {
 //     reshape2
 //        |
 //     other_op
+//
+// or
+//
+//        |
+//    layer_norm
+//        |
+//     reshape2
+//        |
+//       roll
+//        |
+//     reshape2                       |
+//        |              fuse   layernorm_shift_patition
+//    transpose2          ->          |
+//        |                        other_op
+//     reshape2
+//        |
+//     reshape2
+//        |
+//     other_op
+
 class LayerNormShiftPartitionFusePass : public FusePassBase {
  public:
   LayerNormShiftPartitionFusePass();
@@ -44,6 +64,7 @@ class LayerNormShiftPartitionFusePass : public FusePassBase {
 
  protected:
   void ApplyImpl(ir::Graph *graph) const override;
+  int ApplyPattern(ir::Graph *graph, bool with_roll) const;
 
  private:
   const std::string scope_name_{"layernorm_shift_partition_fuse"};
diff --git a/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc b/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc
index 15f2663ce5..147c9a9731 100644
--- a/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc
@@ -40,11 +40,9 @@ class LayerNormShiftPartitionOpConverter : public OpConverter {
                 : 1e-5f;
     const int window_size =
         PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
+    const int shift_size = PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"));
     const int input_resolution =
         PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));
-    // int shift_size = window_size / 2;
-    // shift_size = (input_resolution <= window_size) ? 0 : shift_size;
-    int shift_size = 0;
 
     PADDLE_ENFORCE_NOT_NULL(
         Bias_v,
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py
index a4d74611fe..d2a93adc2a 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_layernorm_shift_partition_pass.py
@@ -15,6 +15,7 @@
 from auto_scan_test import PassAutoScanTest, IgnoreReasons
 from program_config import TensorConfig, ProgramConfig, OpConfig
 import numpy as np
+import math
 import paddle.inference as paddle_infer
 from functools import partial
 from typing import Optional, List, Callable, Dict, Any, Set
@@ -222,5 +223,215 @@ class TestLayernormShiftPartitionPass(PassAutoScanTest):
                             min_success_num=50)
 
 
+class TestLayernormShiftPartition2Pass(PassAutoScanTest):
+    """
+         |
+     layer_norm
+         |
+      reshape2
+         |
+        roll
+         |
+      reshape2
+         |
+     transpose2
+         |
+      reshape2
+         |
+      reshape2
+         |
+    """
+
+    def sample_predictor_configs(self, program_config):
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=1,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Float32,
+            use_static=False,
+            use_calib_mode=False)
+        config.set_trt_dynamic_shape_info({
+            "input_data": [1, 9, 96],
+        }, {
+            "input_data": [4, 3136, 768],
+        }, {
+            "input_data": [1, 784, 384],
+        })
+        yield config, ['layernorm_shift_partition'], (1e-5, 1e-5)
+
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=4,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Half,
+            use_static=False,
+            use_calib_mode=False)
+        config.set_trt_dynamic_shape_info({
+            "input_data": [1, 9, 96],
+        }, {
+            "input_data": [4, 3136, 768],
+        }, {
+            "input_data": [1, 784, 384],
+        })
+        yield config, ['layernorm_shift_partition'], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        axis = [0, 1, 3, 2, 4, 5]
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        # begin_norm_axis has to be 2
+        begin_norm_axis = 2
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        window_size = draw(st.sampled_from([3, 5, 7]))
+        move_shape = draw(st.integers(min_value=1, max_value=8))
+        dim = draw(st.sampled_from([96, 192, 384, 768]))
+
+        def generate_input(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"],
+                 *attrs[1]["input_dim"]]).astype(np.float32)
+
+        def generate_weight(attrs):
+            return np.random.random(attrs[1]['input_dim'][-1]).astype(
+                np.float32)
+
+        attrs = [{
+            'begin_norm_axis': begin_norm_axis,
+            'epsilon': epsilon,
+        }, {
+            'batch_size': batch_size,
+            'input_dim': [(window_size * move_shape)**2, dim],
+        }, {
+            'axis': axis,
+            'input_resolution': window_size * move_shape,
+            'move_shape': move_shape,
+            'window_size': window_size,
+        }]
+
+        layer_norm_op = OpConfig(type="layer_norm",
+                                 inputs={
+                                     "X": ["input_data"],
+                                     "Bias": ["layer_norm_bias"],
+                                     "Scale": ["layer_norm_scale"]
+                                 },
+                                 outputs={
+                                     "Y": ["layer_norm_output1"],
+                                     "Mean": ["layer_norm_output2"],
+                                     "Variance": ["layer_norm_output3"]
+                                 },
+                                 attrs={
+                                     "begin_norm_axis":
+                                     attrs[0]["begin_norm_axis"],
+                                     "epsilon": attrs[0]["epsilon"],
+                                 })
+        reshape_op2 = OpConfig(type="reshape2",
+                               inputs={
+                                   "X": ["layer_norm_output1"],
+                               },
+                               outputs={
+                                   "Out": ["reshape_output2"],
+                                   "XShape": ["reshape_output2_xshape"],
+                               },
+                               attrs={
+                                   'shape': [
+                                       -1, attrs[2]["input_resolution"],
+                                       attrs[2]["input_resolution"],
+                                       attrs[1]["input_dim"][-1]
+                                   ]
+                               })
+        roll_op1 = OpConfig(type="roll",
+                            inputs={"X": ["reshape_output2"]},
+                            outputs={"Out": ["roll_output1"]},
+                            attrs={
+                                "axis": [1, 2],
+                                "shifts": [
+                                    -math.floor(
+                                        (attrs[2]["window_size"]) / 2.0),
+                                    -math.floor((attrs[2]["window_size"]) / 2.0)
+                                ]
+                            })
+        reshape_op3 = OpConfig(type="reshape2",
+                               inputs={
+                                   "X": ["roll_output1"],
+                               },
+                               outputs={
+                                   "Out": ["reshape_output3"],
+                                   "XShape": ["reshape_output3_xshape"],
+                               },
+                               attrs={
+                                   'shape': [
+                                       -1, attrs[2]["move_shape"],
+                                       attrs[2]["window_size"],
+                                       attrs[2]["move_shape"],
+                                       attrs[2]["window_size"],
+                                       attrs[1]["input_dim"][-1]
+                                   ]
+                               })
+        transpose_op4 = OpConfig(type='transpose2',
+                                 inputs={
+                                     "X": ["reshape_output3"],
+                                 },
+                                 outputs={"Out": ["transpose_output4"]},
+                                 attrs={"axis": attrs[2]['axis']})
+        reshape_op5 = OpConfig(type="reshape2",
+                               inputs={
+                                   "X": ["transpose_output4"],
+                               },
+                               outputs={
+                                   "Out": ["reshape_output5"],
+                                   "XShape": ["reshape_output5_xshape"],
+                               },
+                               attrs={
+                                   'shape': [
+                                       -1, attrs[2]["window_size"],
+                                       attrs[2]["window_size"],
+                                       attrs[1]["input_dim"][-1]
+                                   ]
+                               })
+        reshape_op6 = OpConfig(
+            type="reshape2",
+            inputs={
+                "X": ["reshape_output5"],
+            },
+            outputs={
+                "Out": ["reshape_output6"],
+                "XShape": ["reshape_output6_xshape"],
+            },
+            attrs={
+                'shape':
+                [-1, attrs[2]["window_size"]**2, attrs[1]["input_dim"][-1]]
+            })
+
+        program_config = ProgramConfig(
+            ops=[
+                layer_norm_op, reshape_op2, roll_op1, reshape_op3,
+                transpose_op4, reshape_op5, reshape_op6
+            ],
+            weights={
+                "layer_norm_bias":
+                TensorConfig(data_gen=partial(generate_weight, attrs)),
+                "layer_norm_scale":
+                TensorConfig(data_gen=partial(generate_weight, attrs))
+            },
+            inputs={
+                "input_data":
+                TensorConfig(data_gen=partial(generate_input, attrs)),
+            },
+            outputs=["reshape_output6"])
+
+        return program_config
+
+    def test(self):
+        self.run_and_statis(quant=False,
+                            max_examples=50,
+                            passes=["layernorm_shift_partition_fuse_pass"],
+                            max_duration=250,
+                            min_success_num=50)
+
+
 if __name__ == "__main__":
     unittest.main()
--
GitLab