From 111075a349054acb67d272450da4dc5f81ad61c8 Mon Sep 17 00:00:00 2001
From: wenbin
Date: Tue, 31 Jan 2023 20:07:54 +0800
Subject: [PATCH] gn_silu (#49928)

* gn_silu
* add ut
* set TIMEOUT
* correct comments
* comments
* disable windows ut
* rename parameter
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 .../fluid/framework/ir/groupnorm_act_pass.cc  | 167 ++++++++++++++++++
 .../fluid/framework/ir/groupnorm_act_pass.h   |  81 +++++++++
 .../inference/api/paddle_pass_builder.cc      |   1 +
 .../tensorrt/convert/group_norm_op.cc         |   6 +
 .../plugin/common/groupNormPluginCommon.h     |   4 +-
 .../tensorrt/plugin/group_norm_op_plugin.cu   |   8 +-
 .../tensorrt/plugin/group_norm_op_plugin.h    |   8 +-
 .../plugin/preln_groupnorm_act_op_plugin.cu   |   6 +-
 .../plugin/skip_groupnorm_act_op_plugin.cu    |   6 +-
 .../unittests/ir/inference/CMakeLists.txt     |   3 +
 .../test_groupnorm_act_pass_fuse_pass.py      | 150 ++++++++++++++++
 12 files changed, 428 insertions(+), 13 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/groupnorm_act_pass.cc
 create mode 100644 paddle/fluid/framework/ir/groupnorm_act_pass.h
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_groupnorm_act_pass_fuse_pass.py

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index b387dc1d6c..23d5b0de24 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -144,6 +144,7 @@ if(WITH_TENSORRT)
   pass_library(trt_support_nhwc_pass inference)
   pass_library(elementwise_groupnorm_act_pass inference)
   pass_library(preln_elementwise_groupnorm_act_pass inference)
+  pass_library(groupnorm_act_pass inference)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
 endif()
diff --git a/paddle/fluid/framework/ir/groupnorm_act_pass.cc b/paddle/fluid/framework/ir/groupnorm_act_pass.cc
new file mode 100644
index 0000000000..397a743775
--- /dev/null
+++ b/paddle/fluid/framework/ir/groupnorm_act_pass.cc
@@ -0,0 +1,167 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/groupnorm_act_pass.h"
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+class Node;
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct GroupNormAct : public PatternBase {
+  GroupNormAct(PDPattern *pattern, const std::string &name_scope)
+      : PatternBase(pattern, name_scope, "groupnorm_act") {}
+
+  void operator()(PDNode *x);
+  // declare operator node's name
+  PATTERN_DECL_NODE(group_norm);
+  // declare variable node's name
+  PATTERN_DECL_NODE(elementwise_out);
+
+  PATTERN_DECL_NODE(group_norm_bias);
+  PATTERN_DECL_NODE(group_norm_scale);
+  PATTERN_DECL_NODE(group_norm_out);
+  PATTERN_DECL_NODE(act);
+  PATTERN_DECL_NODE(act_out);
+};
+
+void GroupNormAct::operator()(PDNode *x) {
+  // Create nodes for group_norm op.
+  auto *group_norm =
+      pattern->NewNode(group_norm_repr())->assert_is_op("group_norm");
+  auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr())
+                                  ->AsInput()
+                                  ->assert_is_persistable_var()
+                                  ->assert_is_op_input("group_norm", "Bias");
+
+  auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr())
+                                   ->AsInput()
+                                   ->assert_is_persistable_var()
+                                   ->assert_is_op_input("group_norm", "Scale");
+
+  auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
+                                 ->AsOutput()
+                                 ->assert_is_op_output("group_norm", "Y")
+                                 ->assert_is_op_input("silu", "X");
+
+  // Add links for group_norm op.
+  group_norm->LinksFrom({x, group_norm_bias_var, group_norm_scale_var})
+      .LinksTo({group_norm_out_var});
+
+  auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
+  auto *act_out = pattern->NewNode(act_out_repr())
+                      ->AsOutput()
+                      ->assert_is_op_output("silu", "Out");
+
+  act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+}
+
+}  // namespace patterns
+
+int GroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  FusePassBase::Init("groupnorm_silu_fuse", graph);
+
+  int found_subgraph_count = 0;
+
+  GraphPatternDetector gpd;
+  PDNode *x = nullptr;
+
+  x = gpd.mutable_pattern()
+          ->NewNode("groupnorm_act_fuse/x")
+          ->AsInput()
+          ->assert_var_not_persistable()
+          ->assert_is_op_input("group_norm", "X");
+
+  patterns::GroupNormAct fused_pattern(gpd.mutable_pattern(),
+                                       "groupnorm_act_fuse");
+  fused_pattern(x);
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *graph) {
+    if (subgraph.count(x) <= 0) {
+      LOG(WARNING) << "The subgraph is empty.";
+      return;
+    }
+
+    VLOG(4) << "handle groupnorm act fuse";
+
+    GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        group_norm_scale, group_norm_scale, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
+
+    if (!IsCompat(subgraph, graph)) {
+      LOG(WARNING) << "op compat check failed in groupnorm_act fuse pass.";
+      return;
+    }
+
+    std::unordered_set<const Node *> del_node_set;
+    // Reuse the matched group_norm's OpDesc and mark it to apply silu in the
+    // fused kernel.
+    OpDesc new_desc(*group_norm->Op());
+    new_desc.SetAttr("with_silu", true);
+    new_desc.SetOutput("Y", {act_out->Name()});
+    new_desc.Flush();
+
+    auto fused_node = graph->CreateOpNode(&new_desc);  // OpDesc will be copied.
+
+    del_node_set.insert(group_norm);
+    del_node_set.insert(group_norm_out);
+    del_node_set.insert(act);
+    GraphSafeRemoveNodes(graph, del_node_set);
+
+    IR_NODE_LINK_TO(subgraph.at(x), fused_node);
+    IR_NODE_LINK_TO(group_norm_scale, fused_node);
+    IR_NODE_LINK_TO(group_norm_bias, fused_node);
+    IR_NODE_LINK_TO(fused_node, act_out);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+void GroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
+  FusePassBase::Init("groupnorm_act_fuse_pass", graph);
+  int found_subgraph_count = ApplyGNSiluPattern(graph);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(groupnorm_act_pass, paddle::framework::ir::GroupNormActFusePass);
+REGISTER_PASS_CAPABILITY(groupnorm_act_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("silu", 0)
+            .EQ("group_norm", 0));
diff --git a/paddle/fluid/framework/ir/groupnorm_act_pass.h b/paddle/fluid/framework/ir/groupnorm_act_pass.h
new file mode 100644
index 0000000000..16e4d332d2
--- /dev/null
+++ b/paddle/fluid/framework/ir/groupnorm_act_pass.h
@@ -0,0 +1,81 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+//
+//      |                          |
+//  group_norm      ->      group_norm(with_silu)
+//      |                          |
+//     silu
+//      |
+
+class Graph;
+
+class GroupNormActFusePass : public FusePassBase {
+ public:
+  GroupNormActFusePass() {
+    AddOpCompat(OpCompat("group_norm"))
+        .AddInput("X")
+        .IsTensor()
+        .End()
+        .AddInput("Scale")
+        .IsTensor()
+        .End()
+        .AddInput("Bias")
+        .IsTensor()
+        .End()
+        .AddOutput("Y")
+        .IsTensor()
+        .End()
+        .AddOutput("Mean")
+        .IsTensor()
+        .End()
+        .AddOutput("Variance")
+        .IsTensor()
+        .End()
+        .AddAttr("epsilon")
+        .IsNumGE(0.0f)
+        .IsNumLE(1.0f)
+        .End()
+        .AddAttr("groups")
+        .IsNumGE(1)
+        .End()
+        .AddAttr("data_layout")
+        .IsStringIn({"NCHW"})
+        .End();
+    AddOpCompat(OpCompat("silu"))
+        .AddInput("X")
+        .IsTensor()
+        .End()
+        .AddOutput("Out")
+        .IsTensor()
+        .End();
+  }
+
+  virtual ~GroupNormActFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+  int ApplyGNSiluPattern(ir::Graph* graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 9f28343525..b5582518ea 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -136,6 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
 #else
       "elementwise_groupnorm_act_pass",        //
       "preln_elementwise_groupnorm_act_pass",  //
+      "groupnorm_act_pass",                    //
 #endif
       "tensorrt_subgraph_pass",  //
       "conv_bn_fuse_pass",       //
diff --git a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
index 2afc86dfc8..4384f7d2b3 100644
--- a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
@@ -46,6 +46,11 @@ class GroupNormOpConverter : public OpConverter {
     std::string scale_name = op_desc.Input("Scale").front();
     std::string bias_name = op_desc.Input("Bias").front();
 
+    bool with_silu = false;
+    if (op_desc.HasAttr("with_silu")) {
+      with_silu = PADDLE_GET_CONST(bool, op_desc.GetAttr("with_silu"));
+    }
+
     // get the presistable var's data
     auto GetWeight = [&](const std::string& var_name,
                          framework::DDim* dims) -> TensorRTEngine::Weight {
@@ -77,6 +82,7 @@ class GroupNormOpConverter : public OpConverter {
                                          groups,
                                          mean_shape,
                                          variance_shape,
+                                         with_silu,
                                          with_fp16);
       nvinfer1::ILayer* groupnorm_layer =
           engine_->AddDynamicPlugin(&input_itensor, 1, plugin);
diff --git a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
index 81d507e866..915ee1b5e2 100644
--- a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
+++ b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h
@@ -49,8 +49,8 @@ struct GroupNormNHWCParams {
   int32_t c;
   // The number of groups.
   int32_t groups;
-  // Do we apply the Swish activation function?
-  bool withSwish;
+  // Do we apply the Silu activation function?
+  bool withSilu;
 
   // Precomputed values and parameters to control the execution of the kernels.
diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
index 77c00d47d4..fc139a9734 100644
--- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu
@@ -247,8 +247,8 @@ __global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
     f2.x = gammaF2.x * f2.x + betaF2.x;
     f2.y = gammaF2.y * f2.y + betaF2.y;
 
-    // Apply Swish if needed.
-    if (params.withSwish) {
+    // Apply Silu if needed.
+    if (params.withSilu) {
       f2.x = f2.x * sigmoid(f2.x);
       f2.y = f2.y * sigmoid(f2.y);
     }
@@ -457,7 +457,7 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
   if (pos == 0) {
     if (with_fp16_) {
       return ((in.type == nvinfer1::DataType::kHALF) &&
-              (in.format == nvinfer1::PluginFormat::kLINEAR ||
+              ((!with_silu_ && in.format == nvinfer1::PluginFormat::kLINEAR) ||
                in.format == nvinfer1::PluginFormat::kHWC8));
     } else {
       return (in.type == nvinfer1::DataType::kFLOAT) &&
@@ -624,7 +624,7 @@ int GroupNormPluginDynamic::enqueue(
       cPerBlock = 8;
     }
 
-    params_.withSwish = false;
+    params_.withSilu = with_silu_;
    params_.dst = static_cast<half *>(outputs[0]);
    params_.srcX = static_cast<half const *>(inputs[0]);
    params_.gamma = scale_gpu_;
diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h
index 1fa505c077..3feb35e070 100644
--- a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h
@@ -164,11 +164,13 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
                          int groups,
                          std::vector<int64_t> mean_shape,
                          std::vector<int64_t> variance_shape,
+                         bool with_silu,
                          bool with_fp16)
       : groups_(groups),
         eps_(eps),
         mean_shape_(mean_shape),
         variance_shape_(variance_shape),
+        with_silu_(with_silu),
         with_fp16_(with_fp16) {
     scale_.resize(scale_num);
     bias_.resize(bias_num);
@@ -183,6 +185,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &groups_);
     DeserializeValue(&serialData, &serialLength, &mean_shape_);
     DeserializeValue(&serialData, &serialLength, &variance_shape_);
+    DeserializeValue(&serialData, &serialLength, &with_silu_);
     DeserializeValue(&serialData, &serialLength, &with_fp16_);
   }
   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
@@ -194,6 +197,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
                                    groups_,
                                    mean_shape_,
                                    variance_shape_,
+                                   with_silu_,
                                    with_fp16_);
     ptr->scale_gpu_ = scale_gpu_;
     ptr->bias_gpu_ = bias_gpu_;
@@ -210,7 +214,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     return SerializedSize(scale_) + SerializedSize(bias_) +
            SerializedSize(eps_) + SerializedSize(groups_) +
            SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
-           SerializedSize(with_fp16_);
+           SerializedSize(with_silu_) + SerializedSize(with_fp16_);
   }
   void serialize(void* buffer) const TRT_NOEXCEPT override {
     SerializeValue(&buffer, scale_);
@@ -219,6 +223,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     SerializeValue(&buffer, groups_);
     SerializeValue(&buffer, mean_shape_);
     SerializeValue(&buffer, variance_shape_);
+    SerializeValue(&buffer, with_silu_);
     SerializeValue(&buffer, with_fp16_);
   }
   nvinfer1::DimsExprs getOutputDimensions(
@@ -277,6 +282,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
  std::vector<int64_t> mean_shape_;
  std::vector<int64_t> variance_shape_;
  GroupNormNHWCParams params_;
+  bool with_silu_;
  bool with_fp16_;
 };
 class GroupNormPluginDynamicCreator : public TensorRTPluginCreator {
diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
index a756a826bf..d3ca36770a 100644
--- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu
@@ -330,8 +330,8 @@ __global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
     f2.x = gammaF2.x * f2.x + betaF2.x;
     f2.y = gammaF2.y * f2.y + betaF2.y;
 
-    // Apply Swish if needed.
-    if (params.withSwish) {
+    // Apply Silu if needed.
+    if (params.withSilu) {
       f2.x = f2.x * sigmoid(f2.x);
       f2.y = f2.y * sigmoid(f2.y);
     }
@@ -431,7 +431,7 @@ int PrelnGroupnormActPluginDynamic::enqueue(
    if (cPerBlock > input_desc[0].dims.d[1]) {
      cPerBlock = 8;
    }
-    params_.withSwish = with_silu_;
+    params_.withSilu = with_silu_;
    params_.dst = static_cast<half *>(outputs[1]);
    params_.eleOut = static_cast<half *>(outputs[0]);
    params_.srcX = static_cast<half const *>(inputs[0]);
diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu
index adba932447..997205e918 100644
--- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu
@@ -340,8 +340,8 @@ __global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
     f2.x = gammaF2.x * f2.x + betaF2.x;
     f2.y = gammaF2.y * f2.y + betaF2.y;
 
-    // Apply Swish if needed.
-    if (params.withSwish) {
+    // Apply Silu if needed.
+    if (params.withSilu) {
       f2.x = f2.x * sigmoid(f2.x);
       f2.y = f2.y * sigmoid(f2.y);
     }
@@ -439,7 +439,7 @@ int SkipGroupnormActPluginDynamic::enqueue(
    if (cPerBlock > input_desc[0].dims.d[1]) {
      cPerBlock = 8;
    }
-    params_.withSwish = true;
+    params_.withSilu = true;
    params_.dst = static_cast<half *>(outputs[0]);
    params_.srcX = static_cast<half const *>(inputs[0]);
    params_.srcY = static_cast<half const *>(inputs[1]);
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index d456a86aa9..bdcf6ab951 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -39,6 +39,7 @@ if(WIN32)
                                        "test_preln_groupnorm_act_fuse_pass")
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
                                        "test_element_groupnorm_act_fuse_pass")
+  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_groupnorm_act_pass_fuse_pass")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
   list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
 endif()
@@ -225,6 +226,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
                            PROPERTIES TIMEOUT 120)
     set_tests_properties(test_preln_groupnorm_act_fuse_pass PROPERTIES TIMEOUT
                                                             120)
+    set_tests_properties(test_groupnorm_act_pass_fuse_pass PROPERTIES TIMEOUT
+                                                           120)
   endif()
 endif()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_groupnorm_act_pass_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_groupnorm_act_pass_fuse_pass.py
new file mode 100644
index 0000000000..c9f821b21d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_groupnorm_act_pass_fuse_pass.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+
+import hypothesis.strategies as st
+import numpy as np
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+import paddle.inference as paddle_infer
+
+
+class TestElementGNActPass(PassAutoScanTest):
+    #
+    #      |          fuse           |
+    #  groupnorm       ->     groupnorm(with_silu)
+    #      |                          |
+    #     silu
+    #      |
+    #
+
+    def sample_predictor_configs(self, program_config):
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=1,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Half,
+            use_static=False,
+            use_calib_mode=False,
+        )
+        config.set_trt_dynamic_shape_info(
+            {
+                "input_data": [1, 160, 1, 1],
+            },
+            {
+                "input_data": [4, 1280, 64, 64],
+            },
+            {
+                "input_data": [1, 320, 32, 32],
+            },
+        )
+        yield config, ['group_norm'], (3e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([0, -1]))
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        groups = draw(st.sampled_from([4, 8, 16, 32]))
+        hw = draw(st.sampled_from([1, 8, 16, 32]))
+        channel = draw(st.sampled_from([320, 1280]))
+
+        def generate_input(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim"]]
+            ).astype(np.float32)
+
+        def generate_weight(attrs):
+            return np.random.random(attrs[1]['input_dim'][0]).astype(
+                np.float32
+            )
+
+        attrs = [
+            {
+                'epsilon': epsilon,
+                'groups': groups,
+            },
+            {
+                'batch_size': batch_size,
+                'input_dim': [channel, hw, hw],
+            },
+        ]
+
+        group_norm_op = OpConfig(
+            type="group_norm",
+            inputs={
+                "X": ["input_data"],
+                "Bias": ["group_norm_bias"],
+                "Scale": ["group_norm_scale"],
+            },
+            outputs={
+                "Y": ["group_norm_output1"],
+                "Mean": ["group_norm_output2"],
+                "Variance": ["group_norm_output3"],
+            },
+            attrs={
+                "data_layout": "NCHW",
+                "groups": attrs[0]["groups"],
+                "epsilon": attrs[0]["epsilon"],
+            },
+        )
+        silu_op = OpConfig(
+            type="silu",
+            inputs={
+                "X": ["group_norm_output1"],
+            },
+            outputs={
+                "Out": ["silu_output"],
+            },
+        )
+
+        program_config = ProgramConfig(
+            ops=[
+                group_norm_op,
+                silu_op,
+            ],
+            weights={
+                "group_norm_bias": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+                "group_norm_scale": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+            },
+            inputs={
+                "input_data": TensorConfig(
+                    data_gen=partial(generate_input, attrs)
+                ),
+            },
+            outputs=["silu_output"],
+        )
+
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=["groupnorm_act_pass"],
+            max_duration=250,
+            min_success_num=50,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
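Editor's note: outside of the auto-scan harness above, the subgraph this pass targets can be built directly with Paddle's static-graph API. The sketch below is illustrative only and not part of the patch; it assumes a Paddle 2.x build where `paddle.static.nn.group_norm` and `paddle.nn.functional.silu` are available. The fusion itself only happens later, when the saved program is executed through the TensorRT FP16 subgraph path that schedules `groupnorm_act_pass`.

```python
# Illustrative sketch (not part of the patch): constructs the group_norm -> silu
# pattern that groupnorm_act_pass rewrites into a single group_norm op carrying
# the with_silu attribute. Assumes a Paddle 2.x static-graph build.
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # NCHW input; the channel count must be divisible by the group count.
    x = paddle.static.data(
        name="input_data", shape=[-1, 320, 32, 32], dtype="float32"
    )
    gn = paddle.static.nn.group_norm(
        input=x, groups=32, epsilon=1e-5, data_layout="NCHW"
    )
    # This is the exact pattern matched by the pass: group_norm output feeding silu.
    out = paddle.nn.functional.silu(gn)

# When this program is later run through a TensorRT predictor in FP16 mode, the
# pass collapses the two ops into one group_norm with with_silu=True, and the
# plugin applies y = x * sigmoid(x) inside the normalization kernel.
```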