From adcb0039153e0ea6f381b0de523b50ca4cd450e9 Mon Sep 17 00:00:00 2001
From: wenbin
Date: Thu, 12 Jan 2023 13:19:03 +0800
Subject: [PATCH] more preln_gn patterns (#49728)

* compile fix

* fix compile

* compile fix

* add more preln
---
 .../preln_elementwise_groupnorm_act_pass.cc   |  61 +++++---
 .../ir/preln_elementwise_groupnorm_act_pass.h |   4 +-
 .../convert/preln_groupnorm_act_op.cc         |   2 +
 .../plugin/preln_groupnorm_act_op_plugin.cu   |   2 +-
 .../plugin/preln_groupnorm_act_op_plugin.h    |   8 +-
 .../test_preln_groupnorm_act_fuse_pass.py     | 136 ++++++++++++++++++
 6 files changed, 189 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc
index 478c315b9e..7cbb5c169f 100644
--- a/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc
+++ b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc
@@ -35,7 +35,7 @@ struct PrelnGroupNormAct : public PatternBase {
   PrelnGroupNormAct(PDPattern *pattern, const std::string &name_scope)
       : PatternBase(pattern, name_scope, "preln_groupnorm_act") {}
 
-  void operator()(PDNode *x, PDNode *y);
+  void operator()(PDNode *x, PDNode *y, bool with_act);
 
   // declare operator node's name
   PATTERN_DECL_NODE(elementwise);
   PATTERN_DECL_NODE(group_norm);
@@ -49,7 +49,7 @@ struct PrelnGroupNormAct : public PatternBase {
   PATTERN_DECL_NODE(act_out);
 };
 
-void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
+void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y, bool with_act) {
   auto *elementwise =
       pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
@@ -74,26 +74,28 @@ void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
   auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
                                  ->AsOutput()
-                                 ->assert_is_op_output("group_norm", "Y")
-                                 ->assert_is_op_input("silu", "X");
+                                 ->assert_is_op_output("group_norm", "Y");
 
   // Add links for group_norm op.
   group_norm
       ->LinksFrom(
           {elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
       .LinksTo({group_norm_out_var});
-
-  auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
-  auto *act_out = pattern->NewNode(act_out_repr())
-                      ->AsOutput()
-                      ->assert_is_op_output("silu", "Out");
-
-  act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+  if (with_act) {
+    group_norm_out_var->assert_is_op_input("silu", "X");
+    auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
+    auto *act_out = pattern->NewNode(act_out_repr())
+                        ->AsOutput()
+                        ->assert_is_op_output("silu", "Out");
+
+    act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+  }
 }
 
 }  // namespace patterns
 
-int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
+int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph,
+                                                 bool with_act) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   FusePassBase::Init("preln_groupnorm_silu_fuse", graph);
@@ -118,7 +120,7 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
   patterns::PrelnGroupNormAct fused_pattern(gpd.mutable_pattern(),
                                             "preln_groupnorm_act_fuse");
 
-  fused_pattern(x, y);
+  fused_pattern(x, y, with_act);
 
   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *graph) {
@@ -129,6 +131,9 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
 
     VLOG(4) << "handle preln groupnorm act fuse";
 
+    Node *act = nullptr;
+    Node *act_out = nullptr;
+
     GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
@@ -136,8 +141,12 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(
         group_norm_scale, group_norm_scale, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
+    if (with_act) {
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_act, act, fused_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_act_out, act_out, fused_pattern);
+      act = tmp_act;
+      act_out = tmp_act_out;
+    }
 
     if (!IsCompat(subgraph, graph)) {
       LOG(WARNING) << "preln groupnorm act pass in op compat failed.";
@@ -150,8 +159,13 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     new_desc.SetType("preln_groupnorm_act");
     new_desc.SetInput("X", {subgraph.at(x)->Name()});
     new_desc.SetInput("Y", {subgraph.at(y)->Name()});
+    new_desc.SetAttr("with_silu", with_act);
     new_desc.SetOutput("Out_0", {elementwise_out->Name()});
-    new_desc.SetOutput("Out_1", {act_out->Name()});
+    if (with_act) {
+      new_desc.SetOutput("Out_1", {act_out->Name()});
+    } else {
+      new_desc.SetOutput("Out_1", {group_norm_out->Name()});
+    }
     new_desc.RemoveOutput("Y");
     new_desc.Flush();
 
@@ -159,15 +173,21 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     del_node_set.insert(elementwise);
     del_node_set.insert(group_norm);
-    del_node_set.insert(group_norm_out);
-    del_node_set.insert(act);
+    if (with_act) {
+      del_node_set.insert(act);
+      del_node_set.insert(group_norm_out);
+    }
     GraphSafeRemoveNodes(graph, del_node_set);
 
     IR_NODE_LINK_TO(subgraph.at(x), fused_node);
     IR_NODE_LINK_TO(subgraph.at(y), fused_node);
     IR_NODE_LINK_TO(group_norm_scale, fused_node);
     IR_NODE_LINK_TO(group_norm_bias, fused_node);
-    IR_NODE_LINK_TO(fused_node, act_out);
+    if (with_act) {
+      IR_NODE_LINK_TO(fused_node, act_out);
+    } else {
+      IR_NODE_LINK_TO(fused_node, group_norm_out);
+    }
     IR_NODE_LINK_TO(fused_node, elementwise_out);
     found_subgraph_count++;
   };
 
@@ -178,7 +198,8 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
 
 void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
   FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph);
 
-  int found_subgraph_count = ApplyGNSiluPattern(graph);
+  int found_subgraph_count = ApplyAddGNPattern(graph, true);
+  found_subgraph_count += ApplyAddGNPattern(graph, false);
   AddStatis(found_subgraph_count);
 }
diff --git a/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h
index 367f5101f4..59bc2b1026 100644
--- a/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h
+++ b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h
@@ -25,7 +25,7 @@ namespace ir {
 //        |          |            ->  preln_gn_act
 //   other op    group_norm             |      |
 //                   |                other op
-//                  silu
+//             silu(optional)
 //                   |
 
 class Graph;
@@ -88,7 +88,7 @@ class PrelnGroupNormActFusePass : public FusePassBase {
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
 
-  int ApplyGNSiluPattern(ir::Graph* graph) const;
+  int ApplyAddGNPattern(ir::Graph* graph, bool with_act) const;
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc b/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc
index 27d283f672..1c11562ac9 100644
--- a/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc
@@ -45,6 +45,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
 
     int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
     float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
+    bool with_silu = PADDLE_GET_CONST(bool, op_desc.GetAttr("with_silu"));
 
     std::string scale_name = op_desc.Input("Scale").front();
     std::string bias_name = op_desc.Input("Bias").front();
@@ -75,6 +76,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
             bias_weights.get().count,
             epsilon,
             groups,
+            with_silu,
             with_fp16);
     nvinfer1::ILayer* groupnorm_layer =
         engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
index 35821159ae..a756a826bf 100644
--- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
@@ -431,7 +431,7 @@ int PrelnGroupnormActPluginDynamic::enqueue(
     if (cPerBlock > input_desc[0].dims.d[1]) {
       cPerBlock = 8;
     }
-    params_.withSwish = true;
+    params_.withSwish = with_silu_;
     params_.dst = static_cast<half *>(outputs[1]);
     params_.eleOut = static_cast<half *>(outputs[0]);
     params_.srcX = static_cast<half const *>(inputs[0]);
diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h
index 70f5769f1c..501372b9c3 100644
--- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h
@@ -36,6 +36,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
                                   const int bias_num,
                                   float eps,
                                   int groups,
+                                  bool with_silu,
                                   bool with_fp16,
                                   std::shared_ptr<void> scale_gpu = nullptr,
                                   std::shared_ptr<void> bias_gpu = nullptr)
@@ -43,6 +44,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
         bias_gpu_(bias_gpu),
         groups_(groups),
         eps_(eps),
+        with_silu_(with_silu),
         with_fp16_(with_fp16) {
     scale_.resize(scale_num);
     bias_.resize(bias_num);
@@ -69,6 +71,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &bias_);
     DeserializeValue(&serialData, &serialLength, &eps_);
     DeserializeValue(&serialData, &serialLength, &groups_);
+    DeserializeValue(&serialData, &serialLength, &with_silu_);
     DeserializeValue(&serialData, &serialLength, &with_fp16_);
 
     {
@@ -97,6 +100,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
         bias_.size(),
         eps_,
         groups_,
+        with_silu_,
         with_fp16_,
         scale_gpu_,
         bias_gpu_);
@@ -112,13 +116,14 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     return SerializedSize(scale_) + SerializedSize(bias_) +
            SerializedSize(eps_) + SerializedSize(groups_) +
-           SerializedSize(with_fp16_);
+           SerializedSize(with_silu_) + SerializedSize(with_fp16_);
   }
   void serialize(void* buffer) const TRT_NOEXCEPT override {
     SerializeValue(&buffer, scale_);
     SerializeValue(&buffer, bias_);
     SerializeValue(&buffer, eps_);
     SerializeValue(&buffer, groups_);
+    SerializeValue(&buffer, with_silu_);
     SerializeValue(&buffer, with_fp16_);
   }
   nvinfer1::DimsExprs getOutputDimensions(
@@ -171,6 +176,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
   GroupNormNHWCParams params_;
   int groups_;
   float eps_;
+  bool with_silu_;
   bool with_fp16_;
 };
 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
index 4dd5633d2e..e3b5e24a9c 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py
@@ -169,5 +169,141 @@
     )
 
 
+class TestElementGNNoActPass(PassAutoScanTest):
+    #
+    #      |          |                       |          |
+    #  other_op1  other_op2              other_op1   other_op2
+    #      |          |          fuse         \         /
+    #     elementwise_add         ->      preln_groupnorm_act
+    #        |        |                       |       |
+    #  other_op3  groupnorm              other_op3
+    #                 |
+    #
+
+    def sample_predictor_configs(self, program_config):
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=1,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Half,
+            use_static=False,
+            use_calib_mode=False,
+        )
+        config.set_trt_dynamic_shape_info(
+            {
+                "input_data_x": [1, 160, 1, 1],
+                "input_data_y": [1, 160, 1, 1],
+            },
+            {
+                "input_data_x": [4, 1280, 64, 64],
+                "input_data_y": [4, 1280, 64, 64],
+            },
+            {
+                "input_data_x": [1, 320, 32, 32],
+                "input_data_y": [1, 320, 32, 32],
+            },
+        )
+        yield config, ['preln_groupnorm_act'], (3e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([0, -1]))
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        groups = draw(st.sampled_from([4, 8, 16, 32]))
+        hw = draw(st.sampled_from([1, 8, 16, 32]))
+        channel = draw(st.sampled_from([320, 1280]))
+
+        def generate_input_x(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
+            ).astype(np.float32)
+
+        def generate_input_y(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
+            ).astype(np.float32)
+
+        def generate_weight(attrs):
+            return np.random.random(attrs[1]['input_dim_x'][0]).astype(
+                np.float32
+            )
+
+        attrs = [
+            {
+                'axis': axis,
+                'epsilon': epsilon,
+                'groups': groups,
+            },
+            {
+                'batch_size': batch_size,
+                'input_dim_x': [channel, hw, hw],
+                'input_dim_y': [channel, hw, hw],
+            },
+        ]
+
+        elementwise_add_op = OpConfig(
+            type="elementwise_add",
+            inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
+            outputs={"Out": ["ele_out"]},
+            attrs={"axis": attrs[0]['axis']},
+        )
+        group_norm_op = OpConfig(
+            type="group_norm",
+            inputs={
+                "X": ["ele_out"],
+                "Bias": ["group_norm_bias"],
+                "Scale": ["group_norm_scale"],
+            },
+            outputs={
+                "Y": ["group_norm_output1"],
+                "Mean": ["group_norm_output2"],
+                "Variance": ["group_norm_output3"],
+            },
+            attrs={
+                "data_layout": "NCHW",
+                "groups": attrs[0]["groups"],
+                "epsilon": attrs[0]["epsilon"],
+            },
+        )
+
+        program_config = ProgramConfig(
+            ops=[
+                elementwise_add_op,
+                group_norm_op,
+            ],
+            weights={
+                "group_norm_bias": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+                "group_norm_scale": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+            },
+            inputs={
+                "input_data_x": TensorConfig(
+                    data_gen=partial(generate_input_x, attrs)
+                ),
+                "input_data_y": TensorConfig(
+                    data_gen=partial(generate_input_y, attrs)
+                ),
+            },
+            outputs=["ele_out", "group_norm_output1"],
+        )
+
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=["preln_elementwise_groupnorm_act_pass"],
+            max_duration=250,
+            min_success_num=50,
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
--
GitLab
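
For reference, the fused preln_groupnorm_act op this pass emits computes
Out_0 = X + Y and Out_1 = group_norm(Out_0), with silu applied to Out_1 only
when the new with_silu attribute is true. Below is a minimal NumPy sketch of
those semantics as implied by the pattern and the unit test (NCHW layout);
the function name and signature are illustrative, not Paddle API.

import numpy as np

def silu(x):
    # silu(x) = x * sigmoid(x), the activation the pass optionally fuses
    return x / (1.0 + np.exp(-x))

def preln_groupnorm_act_ref(x, y, scale, bias, groups, epsilon, with_silu):
    # x, y: float32 [N, C, H, W]; scale, bias: float32 [C]; C % groups == 0
    out0 = x + y  # elementwise_add branch -> Out_0
    n, c, h, w = out0.shape
    # Normalize each group of channels over (channels-in-group, H, W).
    g = out0.reshape(n, groups, c // groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    normed = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    out1 = normed * scale.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)
    if with_silu:  # mirrors the with_act flag threaded through the pass
        out1 = silu(out1)
    return out0, out1  # Out_0 and Out_1 of the fused op

With with_silu=True this reproduces the original silu-terminated pattern;
with with_silu=False it covers the new activation-free pattern that this
patch teaches the pass, the TensorRT converter, and the plugin to handle.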