diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 18b3dd6663bc476ceb55499472a53b88e625fffd..8a22eb87db7f9ea6921b64f8844d0ac25b68effa 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -139,6 +139,8 @@ if(WITH_TENSORRT) pass_library(layernorm_shift_partition_fuse_pass inference) pass_library(reverse_roll_fuse_pass inference) pass_library(preln_layernorm_x_fuse_pass inference) + pass_library(elementwise_groupnorm_act_pass inference) + pass_library(preln_elementwise_groupnorm_act_pass inference) endif() if(WITH_TENSORRT) diff --git a/paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.cc b/paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b9d5b938b3962a9f1614e7a34fe2ca4be7aff79 --- /dev/null +++ b/paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.cc @@ -0,0 +1,204 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h" + +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct SkipGroupNormAct : public PatternBase { + SkipGroupNormAct(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "skip_groupnorm_act") {} + + void operator()(PDNode *x, PDNode *y); + // declare operator node's name + PATTERN_DECL_NODE(elementwise); + PATTERN_DECL_NODE(group_norm); + // declare variable node's name + PATTERN_DECL_NODE(elementwise_out); + + PATTERN_DECL_NODE(group_norm_bias); + PATTERN_DECL_NODE(group_norm_scale); + PATTERN_DECL_NODE(group_norm_out); + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(act_out); +}; + +void SkipGroupNormAct::operator()(PDNode *x, PDNode *y) { + auto *elementwise = pattern->NewNode(elementwise_repr()) + ->assert_is_op("elementwise_add") + ->assert_has_n_outputs(1); + auto *elementwise_out_var = + pattern->NewNode(elementwise_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_has_n_outputs(1) + ->assert_is_op_input("group_norm", "X"); + + elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var}); + // Create nodes for group_norm op. 
+ auto *group_norm = + pattern->NewNode(group_norm_repr())->assert_is_op("group_norm"); + auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("group_norm", "Bias"); + + auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("group_norm", "Scale"); + + auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("group_norm", "Y") + ->assert_is_op_input("silu", "X"); + + // Add links for group_norm op. + group_norm + ->LinksFrom( + {elementwise_out_var, group_norm_bias_var, group_norm_scale_var}) + .LinksTo({group_norm_out_var}); + + auto *act = pattern->NewNode(act_repr())->assert_is_op("silu"); + auto *act_out = pattern->NewNode(act_out_repr()) + ->AsOutput() + ->assert_is_op_output("silu", "Out"); + + act->LinksFrom({group_norm_out_var}).LinksTo({act_out}); +} + +} // namespace patterns + +int SkipGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("skip_groupnorm_silu_fuse", graph); + + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + PDNode *x = nullptr; + PDNode *y = nullptr; + + x = gpd.mutable_pattern() + ->NewNode("skip_groupnorm_act_fuse/x") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "X"); + + y = gpd.mutable_pattern() + ->NewNode("skip_groupnorm_act_fuse/y") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "Y") + ->assert_more([&](Node *x) { + auto shape = x->Var()->GetShape(); + if (shape.size() == 2 || + (shape.size() == 4 && shape[3] == 1 && shape[2] == 1)) + return true; + else + return false; + }); + patterns::SkipGroupNormAct fused_pattern(gpd.mutable_pattern(), + "skip_groupnorm_act_fuse"); + fused_pattern(x, y); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + VLOG(4) << "handle skip groupnorm act fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + group_norm_scale, group_norm_scale, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern); + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "skip groupnorm act pass in op compat failed."; + return; + } + + std::unordered_set del_node_set; + // Create an skip_groupnorm_act op node + OpDesc new_desc(*group_norm->Op()); + new_desc.SetType("skip_groupnorm_act"); + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("Y", {subgraph.at(y)->Name()}); + new_desc.SetOutput("Out", {act_out->Name()}); + new_desc.RemoveOutput("Y"); + new_desc.Flush(); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. 
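+      // Rewiring note: the matched subgraph
+      //   x, y -> elementwise_add -> group_norm(Scale, Bias) -> silu -> act_out
+      // collapses into the single fused node created above, which reads x, y,
+      // Scale and Bias directly and writes act_out, so every intermediate
+      // node of the pattern can be dropped from the graph below.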
+ + del_node_set.insert(elementwise); + del_node_set.insert(group_norm); + del_node_set.insert(elementwise_out); + del_node_set.insert(group_norm_out); + del_node_set.insert(act); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(subgraph.at(y), fused_node); + IR_NODE_LINK_TO(group_norm_scale, fused_node); + IR_NODE_LINK_TO(group_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, act_out); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +void SkipGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const { + FusePassBase::Init("skip_groupnorm_act_fuse_pass", graph); + int found_subgraph_count = ApplyGNSiluPattern(graph); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(elementwise_groupnorm_act_pass, + paddle::framework::ir::SkipGroupNormActFusePass); +REGISTER_PASS_CAPABILITY(elementwise_groupnorm_act_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("silu", 0) + .EQ("group_norm", 0)); diff --git a/paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h b/paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..86540ea6effe12432c2de76b953ac426ffeb34f4 --- /dev/null +++ b/paddle/fluid/framework/ir/elementwise_groupnorm_act_pass.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +// +// +// +// | | +// elementwise_add fuse | | +// | -> skip_gn_act +// group_norm | +// | +// silu +// | + +class Graph; + +class SkipGroupNormActFusePass : public FusePassBase { + public: + SkipGroupNormActFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + AddOpCompat(OpCompat("group_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(1.0f) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("data_layout") + .IsStringIn({"NCHW"}) + .End(); + AddOpCompat(OpCompat("silu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + } + + virtual ~SkipGroupNormActFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + int ApplyGNSiluPattern(ir::Graph* graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..478c315b9e37df830b0df3d20fbe665a65be3c5b --- /dev/null +++ b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.cc @@ -0,0 +1,196 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h" +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +class Node; +} // namespace ir +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct PrelnGroupNormAct : public PatternBase { + PrelnGroupNormAct(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "preln_groupnorm_act") {} + + void operator()(PDNode *x, PDNode *y); + // declare operator node's name + PATTERN_DECL_NODE(elementwise); + PATTERN_DECL_NODE(group_norm); + // declare variable node's name + PATTERN_DECL_NODE(elementwise_out); + + PATTERN_DECL_NODE(group_norm_bias); + PATTERN_DECL_NODE(group_norm_scale); + PATTERN_DECL_NODE(group_norm_out); + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(act_out); +}; + +void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) { + auto *elementwise = + pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); + + auto *elementwise_out_var = + pattern->NewNode(elementwise_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->assert_is_op_input("group_norm", "X"); + + elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var}); + // Create nodes for group_norm op. + auto *group_norm = + pattern->NewNode(group_norm_repr())->assert_is_op("group_norm"); + auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("group_norm", "Bias"); + + auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("group_norm", "Scale"); + + auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr()) + ->AsOutput() + ->assert_is_op_output("group_norm", "Y") + ->assert_is_op_input("silu", "X"); + + // Add links for group_norm op. 
+ group_norm + ->LinksFrom( + {elementwise_out_var, group_norm_bias_var, group_norm_scale_var}) + .LinksTo({group_norm_out_var}); + + auto *act = pattern->NewNode(act_repr())->assert_is_op("silu"); + auto *act_out = pattern->NewNode(act_out_repr()) + ->AsOutput() + ->assert_is_op_output("silu", "Out"); + + act->LinksFrom({group_norm_out_var}).LinksTo({act_out}); +} + +} // namespace patterns + +int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("preln_groupnorm_silu_fuse", graph); + + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + PDNode *x = nullptr; + PDNode *y = nullptr; + + x = gpd.mutable_pattern() + ->NewNode("preln_groupnorm_act_fuse/x") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "X"); + + y = gpd.mutable_pattern() + ->NewNode("preln_groupnorm_act_fuse/y") + ->AsInput() + ->assert_var_not_persistable() + ->assert_is_op_input("elementwise_add", "Y"); + + patterns::PrelnGroupNormAct fused_pattern(gpd.mutable_pattern(), + "preln_groupnorm_act_fuse"); + fused_pattern(x, y); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + VLOG(4) << "handle preln groupnorm act fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH( + group_norm_scale, group_norm_scale, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern); + + if (!IsCompat(subgraph, graph)) { + LOG(WARNING) << "preln groupnorm act pass in op compat failed."; + return; + } + + std::unordered_set del_node_set; + // Create an preln_groupnorm_act op node + OpDesc new_desc(*group_norm->Op()); + new_desc.SetType("preln_groupnorm_act"); + new_desc.SetInput("X", {subgraph.at(x)->Name()}); + new_desc.SetInput("Y", {subgraph.at(y)->Name()}); + new_desc.SetOutput("Out_0", {elementwise_out->Name()}); + new_desc.SetOutput("Out_1", {act_out->Name()}); + new_desc.RemoveOutput("Y"); + new_desc.Flush(); + + auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. 
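+      // Rewiring note: unlike the skip variant, the elementwise_add result is
+      // kept alive as "Out_0" because other ops in the pre-layernorm branch
+      // still consume it. The fused node therefore exposes two outputs
+      // (Out_0 = x + y, Out_1 = silu(group_norm(x + y))); only the
+      // elementwise_add, group_norm, silu nodes and the internal group_norm
+      // output are removed, while elementwise_out survives in the graph.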
+ + del_node_set.insert(elementwise); + del_node_set.insert(group_norm); + del_node_set.insert(group_norm_out); + del_node_set.insert(act); + GraphSafeRemoveNodes(graph, del_node_set); + + IR_NODE_LINK_TO(subgraph.at(x), fused_node); + IR_NODE_LINK_TO(subgraph.at(y), fused_node); + IR_NODE_LINK_TO(group_norm_scale, fused_node); + IR_NODE_LINK_TO(group_norm_bias, fused_node); + IR_NODE_LINK_TO(fused_node, act_out); + IR_NODE_LINK_TO(fused_node, elementwise_out); + found_subgraph_count++; + }; + + gpd(graph, handler); + return found_subgraph_count; +} + +void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const { + FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph); + int found_subgraph_count = ApplyGNSiluPattern(graph); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(preln_elementwise_groupnorm_act_pass, + paddle::framework::ir::PrelnGroupNormActFusePass); +REGISTER_PASS_CAPABILITY(preln_elementwise_groupnorm_act_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("elementwise_add", 1) + .EQ("silu", 0) + .EQ("group_norm", 0)); diff --git a/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..367f5101f48518fbe02498be1146a25fffe2a6b3 --- /dev/null +++ b/paddle/fluid/framework/ir/preln_elementwise_groupnorm_act_pass.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { +// +// | | +// elementwise_add fuse | | +// | | -> preln_gn_act +// other op group_norm | | +// | other op +// silu +// | + +class Graph; + +class PrelnGroupNormActFusePass : public FusePassBase { + public: + PrelnGroupNormActFusePass() { + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, -1}) + .End(); + AddOpCompat(OpCompat("group_norm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddOutput("Mean") + .IsTensor() + .End() + .AddOutput("Variance") + .IsTensor() + .End() + .AddAttr("epsilon") + .IsNumGE(0.0f) + .IsNumLE(1.0f) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("data_layout") + .IsStringIn({"NCHW"}) + .End(); + AddOpCompat(OpCompat("silu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + } + + virtual ~PrelnGroupNormActFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + int ApplyGNSiluPattern(ir::Graph* graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 0fb11279ebdf9cb78b316acfcaa2e08d73048b6b..81b6f9206e4cc413ea2e6d2182ce9c4653c47fc8 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2420,6 +2420,8 @@ USE_TRT_CONVERTER(logsigmoid) USE_TRT_CONVERTER(lookup_table) USE_TRT_CONVERTER(expand_v2) USE_TRT_CONVERTER(take_along_axis) +USE_TRT_CONVERTER(skip_groupnorm_act) +USE_TRT_CONVERTER(preln_groupnorm_act) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index b4f511ba9ebedd83d8a9df943cffb3c7eb55d3da..c6c3e3b05fb88df0203c51cbe233613d96ea38f7 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -106,8 +106,8 @@ const std::vector kTRTSubgraphPasses({ "vit_attention_fuse_pass", // #if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading. #else - "trt_skip_layernorm_fuse_pass", // - "preln_skip_layernorm_fuse_pass", // + "trt_skip_layernorm_fuse_pass", // + "preln_skip_layernorm_fuse_pass", // #endif "layernorm_shift_partition_fuse_pass", // "merge_layernorm_fuse_pass", // @@ -128,8 +128,13 @@ const std::vector kTRTSubgraphPasses({ // "yolo_box_fuse_pass", // "dense_fc_to_sparse_pass", // "dense_multihead_matmul_to_sparse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // +#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading. 
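+  // Intentionally empty: on the Windows CI (TensorRT 7.0) the two new
+  // groupnorm fuse passes below are not registered; remove this guard after
+  // the CI TensorRT upgrade.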
+#else + "elementwise_groupnorm_act_pass", // + "preln_elementwise_groupnorm_act_pass", // +#endif + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 617898a4e54b928a21c503226164ed62e8611e6a..76c74d55d1157c2ed3e887a1a283796fb6b5935e 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -94,6 +94,8 @@ list( skip_merge_layernorm_op.cc generic_and_custom_plugin_creater.cc fused_lookup_tables_op.cc + skip_groupnorm_act_op.cc + preln_groupnorm_act_op.cc expand_v2_op.cc) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) diff --git a/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc b/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..27d283f672ddb7ef7966af4ec93092fb6c14cc09 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/preln_groupnorm_act_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h" + +#include + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class PrelnGroupnormActOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid preln_groupnorm_act op to tensorrt " + "preln_groupnorm_act plugin"; + + framework::OpDesc op_desc(op, nullptr); + + auto* input_x = engine_->GetITensor(op_desc.Input("X").front()); + auto* input_y = engine_->GetITensor(op_desc.Input("Y").front()); + std::vector inputs{input_x, input_y}; + + int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups")); + float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")); + + std::string scale_name = op_desc.Input("Scale").front(); + std::string bias_name = op_desc.Input("Bias").front(); + + // get the presistable var's data + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dims) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + (*dims) = temp_tensor->dims(); + + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; + + framework::DDim scale_dims; + framework::DDim bias_dims; + auto scale_weights = GetWeight(scale_name, &scale_dims); + auto bias_weights = GetWeight(bias_name, &bias_dims); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + if (engine_->with_dynamic_shape()) { + plugin::PrelnGroupnormActPluginDynamic* plugin = + new plugin::PrelnGroupnormActPluginDynamic( + static_cast(scale_weights.get().values), + scale_weights.get().count, + static_cast(bias_weights.get().values), + bias_weights.get().count, + epsilon, + groups, + with_fp16); + nvinfer1::ILayer* groupnorm_layer = + engine_->AddDynamicPlugin(inputs.data(), 2, plugin); + std::vector output_names; + output_names.emplace_back(op_desc.Output("Out_0").front()); + output_names.emplace_back(op_desc.Output("Out_1").front()); + RreplenishLayerAndOutput( + groupnorm_layer, "preln_groupnorm_act", output_names, test_mode); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(preln_groupnorm_act, PrelnGroupnormActOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/skip_groupnorm_act_op.cc b/paddle/fluid/inference/tensorrt/convert/skip_groupnorm_act_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a109aeeeff490dabbc3edef81e7ec375f14c925a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/skip_groupnorm_act_op.cc @@ -0,0 +1,92 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h" + +#include + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class SkipGroupnormActOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(4) << "convert a fluid skip_groupnorm_act op to tensorrt " + "skip_groupnorm_act plugin"; + + framework::OpDesc op_desc(op, nullptr); + + auto* inputx = engine_->GetITensor(op_desc.Input("X").front()); + auto* inputy = engine_->GetITensor(op_desc.Input("Y").front()); + std::vector inputs{inputx, inputy}; + + int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups")); + float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")); + + std::string scale_name = op_desc.Input("Scale").front(); + std::string bias_name = op_desc.Input("Bias").front(); + + // get the presistable var's data + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dims) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + (*dims) = temp_tensor->dims(); + + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; + + framework::DDim scale_dims; + framework::DDim bias_dims; + auto scale_weights = GetWeight(scale_name, &scale_dims); + auto bias_weights = GetWeight(bias_name, &bias_dims); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + if (engine_->with_dynamic_shape()) { + plugin::SkipGroupnormActPluginDynamic* plugin = + new plugin::SkipGroupnormActPluginDynamic( + static_cast(scale_weights.get().values), + scale_weights.get().count, + static_cast(bias_weights.get().values), + bias_weights.get().count, + epsilon, + groups, + with_fp16); + nvinfer1::ILayer* groupnorm_layer = + engine_->AddDynamicPlugin(inputs.data(), 2, plugin); + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput( + groupnorm_layer, "skip_groupnorm_act", {output_name}, test_mode); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(skip_groupnorm_act, SkipGroupnormActOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index b49ebbff55d80d44d84c79a03e5a0f923d04c37b..bf0730f6debd06e05459ee96c783c0124c8febab 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2390,6 +2390,22 @@ struct SimpleOpTypeSetTeller : public Teller { } } + if (op_type == "skip_groupnorm_act") { + if (!with_dynamic_shape) { + VLOG(3) << "The skip_groupnorm_act op does not support " + "static shape yet"; + return false; + } + } + + if (op_type == "preln_groupnorm_act") { + if (!with_dynamic_shape) { + VLOG(3) << "The preln_groupnorm_act op does not support " + "static shape yet"; + return false; + } + } + if (op_type == "lookup_table") { if (!with_dynamic_shape) { VLOG(3) << "the lookup_table does not support " @@ -2561,7 +2577,9 @@ struct SimpleOpTypeSetTeller : public Teller { 
"merge_layernorm", "skip_merge_layernorm", "lookup_table_v2", - "expand_v2"}; + "expand_v2", + "skip_groupnorm_act", + "preln_groupnorm_act"}; std::unordered_set teller_set{ "mul", @@ -2709,7 +2727,9 @@ struct SimpleOpTypeSetTeller : public Teller { "skip_merge_layernorm", "lookup_table", "lookup_table_v2", - "expand_v2"}; + "expand_v2", + "skip_groupnorm_act", + "preln_groupnorm_act"}; }; struct GenericPluginTeller : public Teller { diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 4e4d9acdcb874a616fa64b083800c47156642ee0..b0d2eb6d347466a14249c596597eab662a5b540e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -35,6 +35,8 @@ list( prelnlayernorm_shift_partition_op.cu merge_layernorm_op_plugin.cu skip_merge_layernorm_op_plugin.cu + skip_groupnorm_act_op_plugin.cu + preln_groupnorm_act_op_plugin.cu generic_plugin.cu lookup_table.cu many_emb_layernorm_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h index 3798dcf1d501b3bd0717777221fced76699a2bb5..81d507e866a1c2a1fa74b0a06c1af5ed10a00538 100644 --- a/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h +++ b/paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h @@ -27,6 +27,8 @@ namespace plugin { struct GroupNormNHWCParams { // The output buffer. Layout NHWC. __half* dst; + // The output buffer. Layout NHWC. + __half* eleOut; // The input buffer. Layout NHWC. __half const* srcX; // The input buffer. Layout NHWC. diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..3e1c8aa5f842b4d283ba0737371e958101448a42 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.cu @@ -0,0 +1,456 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +nvinfer1::DimsExprs PrelnGroupnormActPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputDims, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + return inputDims[0]; +} + +bool PrelnGroupnormActPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument( + "The input of prelnGroupnormAct plugin shoule not be nullptr.")); + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, + nb_inputs + nb_outputs)); + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return ((in.type == nvinfer1::DataType::kHALF) && + (in.format == nvinfer1::PluginFormat::kHWC8)); + } else { + PADDLE_THROW(platform::errors::Fatal( + "PrelnGroupnormAct TRT Plugin is fp16 only so far")); + + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType PrelnGroupnormActPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +int PrelnGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } + +static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; } + +static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { + int32_t maxDivisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { + maxDivisor = divisor1; + } + if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { + maxDivisor = divisor2; + } + } + } + return maxDivisor; +} + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sumSq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const &a, + GroupSums const &b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +__global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. + typedef cub::BlockScan BlockScan; + + // Allocate shared memory for BlockScan. + __shared__ typename BlockScan::TempStorage tempStorage; + // Allocate shared memory for the groups. We could reduce the amount of shared + // memory reserved. + __shared__ float2 smem[tTHREADS_PER_BLOCK]; + + // The instance in the batch. + int32_t ni = blockIdx.z; + // The channel loaded by that thread (2 channels per thread for F16x2). 
+ int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The first activation loaded by that block. + int32_t hwBegin = blockIdx.y * params.hwPerBlock; + // The last activation loaded by that block. + int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + + // The sums. + float sum = 0.F; + float sumSq = 0.F; + + // Iterate over the activations to compute the sums. + for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { + // The offset. + int64_t offset = static_cast(ni) * params.hwc + + static_cast(hwi) * params.c + ci; + // Fetch two channels per thread. + __half2 h2(0, 0); + if (ci < params.c) { + // int64_t offsetY = static_cast(ni) * params.c + ci; + __half2 y = *reinterpret_cast<__half2 const *>(¶ms.srcY[offset]); + h2 = *reinterpret_cast<__half2 const *>(¶ms.srcX[offset]); + h2 = __hadd2(h2, y); + // elementwise_add + *reinterpret_cast<__half2 *>(¶ms.eleOut[offset]) = h2; + } + + // Extract the two half values. + float2 f2 = __half22float2(h2); + + // Update the sum. + sum += f2.x + f2.y; + // Update the sum of squares. + sumSq += f2.x * f2.x + f2.y * f2.y; + } + + // The group that thread works on and the channel in the group (modulus). + int32_t gi = threadIdx.x * 2 / params.cPerGroup; + int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + + // The data for the summations. + GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + + // Do the segmented scan. + GroupSums out; + BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + + // Store the results for the groups in shared memory (to produce coalesced + // stores later). + if (cj == params.cPerGroup - 2 /* 2 channels per thread */) { + smem[gi] = make_float2(out.sum, out.sumSq); + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The global group index. + int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; + + // Threads that have nothing left to do, exit. + if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + return; + } + + // The first threads (those storing to global memory, load the values). + float2 sums = smem[threadIdx.x]; + + // Store to global memory. + atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); + atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); +} + +void prelnGroupNormNHWCSum(GroupNormNHWCParams const ¶ms, + cudaStream_t stream) { + // Make sure the values are as we expect. + PADDLE_ENFORCE_EQ(params.c % params.cPerBlock, + 0, + platform::errors::InvalidArgument( + "The groupNormNHWCSum of prelnGroupnormAct Plugin got " + "wrong parameters" + "params.c %% params.cPerBlock should be 0, but get %d.", + params.c % params.cPerBlock)); + PADDLE_ENFORCE_EQ( + params.hw % params.hwPerBlock, + 0, + platform::errors::InvalidArgument( + "The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong " + "parameters" + "params.hw %% params.hwPerBlock should be 0, but get %d.", + params.hw % params.hwPerBlock)); + // Make sure a group does not span multiple blocks. + PADDLE_ENFORCE_EQ( + params.cPerBlock % params.cPerGroup, + 0, + platform::errors::InvalidArgument( + "The groupNormNHWCSum of prelnGroupnormAct Plugin got wrong " + "parameters" + "params.cPerBlock %% params.cPerGroup should be 0, but get %d.", + params.cPerBlock % params.cPerGroup)); + dim3 grid; + + // The number of blocks to compute all the channels. + grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. 
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      prelnGroupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      prelnGroupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      prelnGroupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      prelnGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Fatal(
+          "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin "
+          "encounter error"));
+  }
+}
+
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = ci / params.cPerGroup;
+
+  // Load the sum and sum of squares for the group.
+  float sum = 0.F, sumSq = 0.F;
+  if (gi < params.groups) {
+    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
+    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
+  }
+
+  // Load gamma/beta.
+  float2 gammaF2, betaF2;
+  if (ci < params.c) {
+    gammaF2 = *reinterpret_cast<float2 const *>(
+        reinterpret_cast<float const *>(params.gamma) + ci);
+    betaF2 = *reinterpret_cast<float2 const *>(
+        reinterpret_cast<float const *>(params.beta) + ci);
+  }
+
+  // Compute the mean.
+  float mean = sum * params.invHWC;
+  // Compute the variance.
+  float var = sumSq * params.invHWC - (mean * mean);
+  // Compute the inverse of the stddev.
+  float invStdDev = rsqrtf(var + params.eps);
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The src/dst offset.
+    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      h2 = *reinterpret_cast<__half2 const *>(&params.eleOut[offset]);
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Normalize the channels.
+    f2.x = (f2.x - mean) * invStdDev;
+    f2.y = (f2.y - mean) * invStdDev;
+
+    // Scale by gamma and add beta.
+    f2.x = gammaF2.x * f2.x + betaF2.x;
+    f2.y = gammaF2.y * f2.y + betaF2.y;
+
+    // Apply Swish if needed.
+    if (params.withSwish) {
+      f2.x = f2.x * sigmoid(f2.x);
+      f2.y = f2.y * sigmoid(f2.y);
+    }
+
+    // Store the scaled values.
+    if (ci < params.c) {
+      *reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
+    }
+  }
+}
+
+void prelnGroupNormNHWCScale(GroupNormNHWCParams const &params,
+                             cudaStream_t stream) {
+  // Make sure the dimensions are aligned with what we expect.
+  PADDLE_ENFORCE_EQ(
+      params.c % params.cPerBlock,
+      0,
+      platform::errors::InvalidArgument(
+          "The groupNormNHWCScale of prelnGroupnormAct Plugin got "
+          "wrong parameters"
+          "params.c %% params.cPerBlock should be 0, but get %d.",
+          params.c % params.cPerBlock));
+  // Make sure a group does not span multiple blocks.
+  PADDLE_ENFORCE_EQ(
+      params.cPerBlock % params.cPerGroup,
+      0,
+      platform::errors::InvalidArgument(
+          "The groupNormNHWCScale of prelnGroupnormAct Plugin got wrong "
+          "parameters"
+          "params.cPerBlock %% params.cPerGroup should be 0, but get %d.",
+          params.cPerBlock % params.cPerGroup));
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+ grid.x = params.c / params.cPerBlock; + // The number of blocks to compute all the activations in a given instance. + grid.y = divUp(params.hw, params.hwPerBlock); + // The number of instances. + grid.z = params.n; + + switch (params.cPerBlock) { + case 320: + prelnGroupNormNHWCScaleKernel<160><<>>(params); + break; + case 480: + prelnGroupNormNHWCScaleKernel<256><<>>(params); + break; + case 256: + prelnGroupNormNHWCScaleKernel<128><<>>(params); + break; + case 128: + prelnGroupNormNHWCScaleKernel<64><<>>(params); + break; + default: + PADDLE_THROW(platform::errors::Fatal( + "The function groupNormNHWCSum of prelnGroupnormAct TRT Plugin " + "encounter error")); + } +} + +int PrelnGroupnormActPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, + void *const *outputs, + void *workspace, + cudaStream_t stream) TRT_NOEXCEPT { + auto input_type = input_desc[0].type; + if (input_type == nvinfer1::DataType::kFLOAT) { + VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp32"; + PADDLE_THROW(platform::errors::Fatal( + "The prelnGroupnormAct TRT Plugin's only support fp16 input")); + } else if (input_type == nvinfer1::DataType::kHALF) { + VLOG(1) << "TRT Plugin DataType selected. prelnGroupnormAct-->fp16"; + + int32_t cPerBlock = 320; + int32_t maxBlocksPerHW = 1024; + + switch (input_desc[0].dims.d[1]) { + case 960: + case 1920: + cPerBlock = 480; + break; + case 512: + case 256: + cPerBlock = 256; + break; + case 128: + cPerBlock = 128; + break; + default: + cPerBlock = 320; + } + params_.withSwish = true; + params_.dst = static_cast(outputs[1]); + params_.eleOut = static_cast(outputs[0]); + params_.srcX = static_cast(inputs[0]); + params_.srcY = static_cast(inputs[1]); + params_.gamma = scale_gpu_.get(); + params_.beta = bias_gpu_.get(); + params_.redBuffer = static_cast(workspace); + params_.n = input_desc[0].dims.d[0]; + params_.h = input_desc[0].dims.d[2]; + params_.w = input_desc[0].dims.d[3]; + params_.c = input_desc[0].dims.d[1]; + params_.groups = groups_; + params_.hw = params_.h * params_.w; + const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW); + params_.hwPerBlock = divUp(params_.hw, blocksPerHW); + params_.cPerBlock = cPerBlock; + params_.cPerGroup = params_.c / params_.groups; + params_.hwc = params_.hw * params_.c; + params_.invHWC = 1.F / static_cast(params_.hw * params_.cPerGroup); + params_.groupsPerBlock = cPerBlock / params_.cPerGroup; + params_.eps = eps_; + + cudaMemsetAsync(params_.redBuffer, 0, ws_, stream); + prelnGroupNormNHWCSum(params_, stream); + prelnGroupNormNHWCScale(params_, stream); + + } else { + // input not fp16 + PADDLE_THROW(platform::errors::Fatal( + "The PrelnGroupnormAct TRT Plugin's only support fp16 input")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..70f5769f1c343d6ce4932c4fa5e8cb8798251d8c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h @@ -0,0 +1,194 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { + public: + PrelnGroupnormActPluginDynamic(const float* scale, + const int scale_num, + const float* bias, + const int bias_num, + float eps, + int groups, + bool with_fp16, + std::shared_ptr scale_gpu = nullptr, + std::shared_ptr bias_gpu = nullptr) + : scale_gpu_(scale_gpu), + bias_gpu_(bias_gpu), + groups_(groups), + eps_(eps), + with_fp16_(with_fp16) { + scale_.resize(scale_num); + bias_.resize(bias_num); + std::copy(scale, scale + scale_num, scale_.data()); + std::copy(bias, bias + bias_num, bias_.data()); + if (scale_gpu_ == nullptr) { + void* p; + cudaMalloc(reinterpret_cast(&p), scale_num * sizeof(float)); + scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); }); + cudaMemcpy( + p, scale_.data(), scale_num * sizeof(float), cudaMemcpyHostToDevice); + } + if (bias_gpu_ == nullptr) { + void* p; + cudaMalloc(reinterpret_cast(&p), bias_num * sizeof(float)); + bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); }); + cudaMemcpy( + p, bias_.data(), bias_num * sizeof(float), cudaMemcpyHostToDevice); + } + } + + PrelnGroupnormActPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &bias_); + DeserializeValue(&serialData, &serialLength, &eps_); + DeserializeValue(&serialData, &serialLength, &groups_); + DeserializeValue(&serialData, &serialLength, &with_fp16_); + + { + void* p; + cudaMalloc(reinterpret_cast(&p), scale_.size() * sizeof(float)); + scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); }); + cudaMemcpy(p, + scale_.data(), + scale_.size() * sizeof(float), + cudaMemcpyHostToDevice); + } + { + void* p; + cudaMalloc(reinterpret_cast(&p), bias_.size() * sizeof(float)); + bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); }); + cudaMemcpy(p, + bias_.data(), + bias_.size() * sizeof(float), + cudaMemcpyHostToDevice); + } + } + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + auto* ptr = new PrelnGroupnormActPluginDynamic(scale_.data(), + scale_.size(), + bias_.data(), + bias_.size(), + eps_, + groups_, + with_fp16_, + scale_gpu_, + bias_gpu_); + return ptr; + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "preln_groupnorm_act_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 2; } + int initialize() TRT_NOEXCEPT override; + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(scale_) + SerializedSize(bias_) + + SerializedSize(eps_) + SerializedSize(groups_) + + SerializedSize(with_fp16_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, scale_); + 
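+    // NOTE: keep this write order (scale, bias, eps, groups, with_fp16) in
+    // sync with the field order read back in the deserializing constructor.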
SerializeValue(&buffer, bias_); + SerializeValue(&buffer, eps_); + SerializeValue(&buffer, groups_); + SerializeValue(&buffer, with_fp16_); + } + nvinfer1::DimsExprs getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) // NOLINT + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override { + // sizeof(float2) * maxBatchSize * maxNumberOfGroup. float2 + // contians two buffers for sum and squared sum; + ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_; + } + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return ws_; + } + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + void terminate() TRT_NOEXCEPT override{}; + + private: + size_t ws_; + std::vector scale_; + std::vector bias_; + std::shared_ptr scale_gpu_; + std::shared_ptr bias_gpu_; + GroupNormNHWCParams params_; + int groups_; + float eps_; + bool with_fp16_; +}; + +class PrelnGroupnormActPluginDynamicCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "preln_groupnorm_act_plugin_dynamic"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new PrelnGroupnormActPluginDynamic(serial_data, serial_length); + } +}; +REGISTER_TRT_PLUGIN_V2(PrelnGroupnormActPluginDynamicCreator); +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..7acd250c50cf72f3913e5cbfa72f0b80d657aba6 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.cu @@ -0,0 +1,463 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h" +#include +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +nvinfer1::DimsExprs SkipGroupnormActPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputDims, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + return inputDims[0]; +} + +bool SkipGroupnormActPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument( + "The input of SkipGroupnormAct plugin shoule not be nullptr.")); + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, + nb_inputs + nb_outputs)); + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + if (with_fp16_) { + return ((in.type == nvinfer1::DataType::kHALF) && + (in.format == nvinfer1::PluginFormat::kHWC8)); + } else { + PADDLE_THROW(platform::errors::Fatal( + "SkipGroupnormAct TRT Plugin is fp16 only so far")); + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType SkipGroupnormActPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ( + index, + 0, + platform::errors::InvalidArgument( + "The SkipGroupnormAct Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT || + input_types[0] == nvinfer1::DataType::kHALF), + true, + platform::errors::InvalidArgument( + "The input type should be half or float")); + + return input_types[0]; +} +int SkipGroupnormActPluginDynamic::initialize() TRT_NOEXCEPT { return 0; } + +static inline int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; } + +static int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { + int32_t maxDivisor = -1; + for (int32_t i = 1; i <= std::sqrt(n); i++) { + if (n % i == 0) { + int32_t divisor1 = n / i; + int32_t divisor2 = i; + + if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { + maxDivisor = divisor1; + } + if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { + maxDivisor = divisor2; + } + } + } + return maxDivisor; +} + +static inline __device__ __host__ float sigmoid(float x) { + return 1.F / (1.F + expf(-x)); +} + +struct GroupSums { + // Is it the 1st element of the group? + int32_t flag; + // The sum. + float sum; + // The sum of squares. + float sumSq; +}; + +struct GroupSumsOp { + inline __device__ GroupSums operator()(GroupSums const &a, + GroupSums const &b) { + GroupSums dst; + dst.sum = b.flag ? b.sum : (a.sum + b.sum); + dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.flag = a.flag + b.flag; + return dst; + } +}; + +template +__global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) { + // The object in charge of doing the sums for the different blocks. 
+  typedef cub::BlockScan<GroupSums, tTHREADS_PER_BLOCK> BlockScan;
+
+  // Allocate shared memory for BlockScan.
+  __shared__ typename BlockScan::TempStorage tempStorage;
+  // Allocate shared memory for the groups. We could reduce the amount of
+  // shared memory reserved.
+  __shared__ float2 smem[tTHREADS_PER_BLOCK];
+
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // The sums.
+  float sum = 0.F;
+  float sumSq = 0.F;
+
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The offset.
+    int64_t offset = static_cast<int64_t>(ni) * params.hwc +
+                     static_cast<int64_t>(hwi) * params.c + ci;
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      // W = 1, H = 1
+      int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
+      __half2 y = *reinterpret_cast<__half2 const *>(&params.srcY[offsetY]);
+      h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
+      h2 = __hadd2(h2, y);
+      // elementwise_add
+      *reinterpret_cast<__half2 *>(&params.dst[offset]) = h2;
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Update the sum.
+    sum += f2.x + f2.y;
+    // Update the sum of squares.
+    sumSq += f2.x * f2.x + f2.y * f2.y;
+  }
+
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = threadIdx.x * 2 / params.cPerGroup;
+  int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi;
+
+  // The data for the summations.
+  GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq};
+
+  // Do the segmented scan.
+  GroupSums out;
+  BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
+
+  // Store the results for the groups in shared memory (to produce coalesced
+  // stores later).
+  if (cj == params.cPerGroup - 2 /* 2 channels per thread */) {
+    smem[gi] = make_float2(out.sum, out.sumSq);
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The global group index.
+  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
+
+  // Threads that have nothing left to do, exit.
+  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
+    return;
+  }
+
+  // The first threads (those storing to global memory) load the values.
+  float2 sums = smem[threadIdx.x];
+
+  // Store to global memory.
+  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
+  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
+}
+
+void skipGroupNormNHWCSum(GroupNormNHWCParams const &params,
+                          cudaStream_t stream) {
+  // Make sure the values are as we expect.
+  PADDLE_ENFORCE_EQ(
+      params.c % params.cPerBlock,
+      0,
+      platform::errors::InvalidArgument(
+          "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong "
+          "parameters: params.c %% params.cPerBlock should be 0, but got %d.",
+          params.c % params.cPerBlock));
+  PADDLE_ENFORCE_EQ(
+      params.hw % params.hwPerBlock,
+      0,
+      platform::errors::InvalidArgument(
+          "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong "
+          "parameters: params.hw %% params.hwPerBlock should be 0, but got "
+          "%d.",
+          params.hw % params.hwPerBlock));
+  // Make sure a group does not span multiple blocks.
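+  // If cPerBlock were not a multiple of cPerGroup, a group's channels would
+  // be split across two blocks and the per-block segmented scan above could
+  // not produce complete per-group statistics.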
+  PADDLE_ENFORCE_EQ(
+      params.cPerBlock % params.cPerGroup,
+      0,
+      platform::errors::InvalidArgument(
+          "The groupNormNHWCSum of SkipGroupnormAct Plugin got wrong "
+          "parameters: params.cPerBlock %% params.cPerGroup should be 0, but "
+          "got %d.",
+          params.cPerBlock % params.cPerGroup));
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+  grid.x = params.c / params.cPerBlock;
+  // The number of blocks to compute all the activations in a given instance.
+  grid.y = divUp(params.hw, params.hwPerBlock);
+  // The number of instances.
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      skipGroupNormNHWCSumKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      skipGroupNormNHWCSumKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      skipGroupNormNHWCSumKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      skipGroupNormNHWCSumKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Fatal(
+          "The function groupNormNHWCSum of SkipGroupnormAct TRT Plugin "
+          "encountered an error"));
+  }
+}
+
+template <int32_t tTHREADS_PER_BLOCK>
+__global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
+  // The instance in the batch.
+  int32_t ni = blockIdx.z;
+  // The channel loaded by that thread (2 channels per thread for F16x2).
+  int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2;
+  // The group that thread works on and the channel in the group (modulus).
+  int32_t gi = ci / params.cPerGroup;
+
+  // Load the sum and sum of squares for the group.
+  float sum = 0.F, sumSq = 0.F;
+  if (gi < params.groups) {
+    sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
+    sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
+  }
+
+  // Load gamma/beta.
+  float2 gammaF2, betaF2;
+  if (ci < params.c) {
+    gammaF2 = *reinterpret_cast<float2 const *>(
+        reinterpret_cast<float const *>(params.gamma) + ci);
+    betaF2 = *reinterpret_cast<float2 const *>(
+        reinterpret_cast<float const *>(params.beta) + ci);
+  }
+
+  // Compute the mean.
+  float mean = sum * params.invHWC;
+  // Compute the variance.
+  float var = sumSq * params.invHWC - (mean * mean);
+  // Compute the inverse of the stddev.
+  float invStdDev = rsqrtf(var + params.eps);
+
+  // The first activation loaded by that block.
+  int32_t hwBegin = blockIdx.y * params.hwPerBlock;
+  // The last activation loaded by that block.
+  int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);
+
+  // Iterate over the activations to normalize and scale them.
+  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+    // The src/dst offset.
+    int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
+
+    // Fetch two channels per thread.
+    __half2 h2(0, 0);
+    if (ci < params.c) {
+      h2 = *reinterpret_cast<__half2 const *>(&params.dst[offset]);
+    }
+
+    // Extract the two half values.
+    float2 f2 = __half22float2(h2);
+
+    // Normalize the channels.
+    f2.x = (f2.x - mean) * invStdDev;
+    f2.y = (f2.y - mean) * invStdDev;
+
+    // Scale by gamma and add beta.
+    f2.x = gammaF2.x * f2.x + betaF2.x;
+    f2.y = gammaF2.y * f2.y + betaF2.y;
+
+    // Apply Swish if needed.
+    if (params.withSwish) {
+      f2.x = f2.x * sigmoid(f2.x);
+      f2.y = f2.y * sigmoid(f2.y);
+    }
+
+    // Store the scaled values.
+    if (ci < params.c) {
+      *reinterpret_cast<__half2 *>(&params.dst[offset]) = __float22half2_rn(f2);
+    }
+  }
+}
+
+void skipGroupNormNHWCScale(GroupNormNHWCParams const &params,
+                            cudaStream_t stream) {
+  // Make sure the dimensions are aligned with what we expect.
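+  // The scale kernel reuses the grid decomposition of the sum kernel, so the
+  // same divisibility requirements apply here.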
+  PADDLE_ENFORCE_EQ(params.c % params.cPerBlock,
+                    0,
+                    platform::errors::InvalidArgument(
+                        "The groupNormNHWCScale of SkipGroupnormAct Plugin "
+                        "got wrong parameters: params.c %% params.cPerBlock "
+                        "should be 0, but got %d.",
+                        params.c % params.cPerBlock));
+  // Make sure a group does not span multiple blocks.
+  PADDLE_ENFORCE_EQ(
+      params.cPerBlock % params.cPerGroup,
+      0,
+      platform::errors::InvalidArgument(
+          "The groupNormNHWCScale of SkipGroupnormAct Plugin got wrong "
+          "parameters: params.cPerBlock %% params.cPerGroup should be 0, but "
+          "got %d.",
+          params.cPerBlock % params.cPerGroup));
+  dim3 grid;
+
+  // The number of blocks to compute all the channels.
+  grid.x = params.c / params.cPerBlock;
+  // The number of blocks to compute all the activations in a given instance.
+  grid.y = divUp(params.hw, params.hwPerBlock);
+  // The number of instances.
+  grid.z = params.n;
+
+  switch (params.cPerBlock) {
+    case 320:
+      skipGroupNormNHWCScaleKernel<160><<<grid, 160, 0, stream>>>(params);
+      break;
+    case 480:
+      skipGroupNormNHWCScaleKernel<256><<<grid, 256, 0, stream>>>(params);
+      break;
+    case 256:
+      skipGroupNormNHWCScaleKernel<128><<<grid, 128, 0, stream>>>(params);
+      break;
+    case 128:
+      skipGroupNormNHWCScaleKernel<64><<<grid, 64, 0, stream>>>(params);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Fatal(
+          "The function groupNormNHWCScale of SkipGroupnormAct TRT Plugin "
+          "encountered an error"));
+  }
+}
+
+int SkipGroupnormActPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *input_desc,
+    const nvinfer1::PluginTensorDesc *output_desc,
+    const void *const *inputs,
+    void *const *outputs,
+    void *workspace,
+    cudaStream_t stream) TRT_NOEXCEPT {
+  auto input_type = input_desc[0].type;
+  if (input_type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. SkipGroupnormAct-->fp32";
+    PADDLE_THROW(platform::errors::Fatal(
+        "The SkipGroupnormAct TRT Plugin only supports fp16 input"));
+  } else if (input_type == nvinfer1::DataType::kHALF) {
+    VLOG(1) << "TRT Plugin DataType selected. SkipGroupnormAct-->fp16";
+    int32_t cPerBlock = 320;
+    int32_t maxBlocksPerHW = 1024;
+
+    switch (input_desc[0].dims.d[1]) {
+      case 960:
+      case 1920:
+        cPerBlock = 480;
+        break;
+      case 512:
+      case 256:
+        cPerBlock = 256;
+        break;
+      case 128:
+        cPerBlock = 128;
+        break;
+      default:
+        cPerBlock = 320;
+    }
+    params_.withSwish = true;
+    params_.dst = static_cast<__half *>(outputs[0]);
+    params_.srcX = static_cast<__half const *>(inputs[0]);
+    params_.srcY = static_cast<__half const *>(inputs[1]);
+    params_.gamma = scale_gpu_.get();
+    params_.beta = bias_gpu_.get();
+    params_.redBuffer = static_cast<float *>(workspace);
+    params_.n = input_desc[0].dims.d[0];
+    params_.h = input_desc[0].dims.d[2];
+    params_.w = input_desc[0].dims.d[3];
+    params_.c = input_desc[0].dims.d[1];
+    params_.groups = groups_;
+    params_.hw = params_.h * params_.w;
+    const int32_t blocksPerHW = findMaxDivisor(params_.hw, maxBlocksPerHW);
+    params_.hwPerBlock = divUp(params_.hw, blocksPerHW);
+    params_.cPerBlock = cPerBlock;
+    params_.cPerGroup = params_.c / params_.groups;
+    params_.hwc = params_.hw * params_.c;
+    params_.invHWC = 1.F / static_cast<float>(params_.hw * params_.cPerGroup);
+    params_.groupsPerBlock = cPerBlock / params_.cPerGroup;
+    params_.eps = eps_;
+
+    cudaMemsetAsync(params_.redBuffer, 0, ws_, stream);
+    skipGroupNormNHWCSum(params_, stream);
+    skipGroupNormNHWCScale(params_, stream);
+
+  } else {
+    // input not fp16
+    PADDLE_THROW(platform::errors::Fatal(
+        "The SkipGroupnormAct TRT Plugin only supports fp16 input"));
+  }
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..5ed9dd14e711dd9756ef35fb0073486ea8046b64
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h
@@ -0,0 +1,194 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/common/groupNormPluginCommon.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT {
+ public:
+  SkipGroupnormActPluginDynamic(const float* scale,
+                                const int scale_num,
+                                const float* bias,
+                                const int bias_num,
+                                float eps,
+                                int groups,
+                                bool with_fp16,
+                                std::shared_ptr<void> scale_gpu = nullptr,
+                                std::shared_ptr<void> bias_gpu = nullptr)
+      : scale_gpu_(scale_gpu),
+        bias_gpu_(bias_gpu),
+        groups_(groups),
+        eps_(eps),
+        with_fp16_(with_fp16) {
+    scale_.resize(scale_num);
+    bias_.resize(bias_num);
+    std::copy(scale, scale + scale_num, scale_.data());
+    std::copy(bias, bias + bias_num, bias_.data());
+    if (scale_gpu_ == nullptr) {
+      void* p;
+      cudaMalloc(reinterpret_cast<void**>(&p), scale_num * sizeof(float));
+      scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
+      cudaMemcpy(
+          p, scale_.data(), scale_num * sizeof(float), cudaMemcpyHostToDevice);
+    }
+    if (bias_gpu_ == nullptr) {
+      void* p;
+      cudaMalloc(reinterpret_cast<void**>(&p), bias_num * sizeof(float));
+      bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
+      cudaMemcpy(
+          p, bias_.data(), bias_num * sizeof(float), cudaMemcpyHostToDevice);
+    }
+  }
+
+  SkipGroupnormActPluginDynamic(void const* serialData, size_t serialLength) {
+    DeserializeValue(&serialData, &serialLength, &scale_);
+    DeserializeValue(&serialData, &serialLength, &bias_);
+    DeserializeValue(&serialData, &serialLength, &eps_);
+    DeserializeValue(&serialData, &serialLength, &groups_);
+    DeserializeValue(&serialData, &serialLength, &with_fp16_);
+
+    {
+      void* p;
+      cudaMalloc(reinterpret_cast<void**>(&p), scale_.size() * sizeof(float));
+      scale_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
+      cudaMemcpy(p,
+                 scale_.data(),
+                 scale_.size() * sizeof(float),
+                 cudaMemcpyHostToDevice);
+    }
+    {
+      void* p;
+      cudaMalloc(reinterpret_cast<void**>(&p), bias_.size() * sizeof(float));
+      bias_gpu_.reset(p, [](void* ptr) { cudaFree(ptr); });
+      cudaMemcpy(p,
+                 bias_.data(),
+                 bias_.size() * sizeof(float),
+                 cudaMemcpyHostToDevice);
+    }
+  }
+  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
+    auto* ptr = new SkipGroupnormActPluginDynamic(scale_.data(),
+                                                  scale_.size(),
+                                                  bias_.data(),
+                                                  bias_.size(),
+                                                  eps_,
+                                                  groups_,
+                                                  with_fp16_,
+                                                  scale_gpu_,
+                                                  bias_gpu_);
+    return ptr;
+  }
+
+  const char* getPluginType() const TRT_NOEXCEPT override {
+    return "skip_groupnorm_act_plugin_dynamic";
+  }
+  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
+  int initialize() TRT_NOEXCEPT override;
+
+  size_t getSerializationSize() const TRT_NOEXCEPT override {
+    return SerializedSize(scale_) + SerializedSize(bias_) +
+           SerializedSize(eps_) + SerializedSize(groups_) +
+           SerializedSize(with_fp16_);
+  }
+  void serialize(void* buffer) const TRT_NOEXCEPT override {
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, bias_);
+    SerializeValue(&buffer, eps_);
+    SerializeValue(&buffer, groups_);
+    SerializeValue(&buffer, with_fp16_);
+  }
+  nvinfer1::DimsExprs getOutputDimensions(
+      int output_index,
+      const nvinfer1::DimsExprs* inputs,
+      int nb_inputs,
+      nvinfer1::IExprBuilder& expr_builder)  // NOLINT
+      TRT_NOEXCEPT override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* inOut,
+                                 int nbInputs,
+                                 int nbOutputs) TRT_NOEXCEPT override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc* out,
+                       int nbOutputs) TRT_NOEXCEPT override {
+    // sizeof(float2) * maxBatchSize * maxNumberOfGroup. float2
+    // contains two buffers for sum and squared sum;
+    ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_;
+  }
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nbOutputs) const TRT_NOEXCEPT override {
+    return ws_;
+  }
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs,
+              void* const* outputs,
+              void* workspace,
+              cudaStream_t stream) TRT_NOEXCEPT override;
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* inputTypes,
+                                       int nbInputs) const
+      TRT_NOEXCEPT override;
+
+  void destroy() TRT_NOEXCEPT override { delete this; }
+  void terminate() TRT_NOEXCEPT override{};
+
+ private:
+  size_t ws_;
+  std::vector<float> scale_;
+  std::vector<float> bias_;
+  std::shared_ptr<void> scale_gpu_;
+  std::shared_ptr<void> bias_gpu_;
+  GroupNormNHWCParams params_;
+  int groups_;
+  float eps_;
+  bool with_fp16_;
+};
+
+class SkipGroupnormActPluginDynamicCreator : public TensorRTPluginCreator {
+ public:
+  const char* getPluginName() const TRT_NOEXCEPT override {
+    return "skip_groupnorm_act_plugin_dynamic";
+  }
+  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length)
+      TRT_NOEXCEPT override {
+    return new SkipGroupnormActPluginDynamic(serial_data, serial_length);
+  }
+};
+REGISTER_TRT_PLUGIN_V2(SkipGroupnormActPluginDynamicCreator);
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/operators/compat/group_norm.pbtxt b/paddle/fluid/operators/compat/group_norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..de4ac1ba38610d1f0e7f61eae4c3721787704e9c
--- /dev/null
+++ b/paddle/fluid/operators/compat/group_norm.pbtxt
@@ -0,0 +1,67 @@
+type: "group_norm"
+def {
+  inputs {
+    name: "X"
+  }
+  inputs {
+    name: "Scale"
+  }
+  inputs {
+    name: "Bias"
+  }
+  outputs {
+    name: "Y"
+  }
+  outputs {
+    name: "Mean"
+  }
+  outputs {
+    name: "Variance"
+  }
+  attrs {
+    name: "epsilon"
+    type: FLOAT
+  }
+  attrs {
+    name: "groups"
+    type: INT
+  }
+  attrs {
+    name: "data_layout"
+    type: STRING
+  }
+}
+extra {
+  attrs {
+    name: "use_mkldnn"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "mkldnn_data_type"
+    type: STRING
+  }
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "op_role"
+    type: INT
+  }
+  attrs {
+    name: "op_role_var"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_namescope"
+    type: STRING
+  }
+  attrs {
+    name: "op_callstack"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_device"
+    type: STRING
+  }
+}
diff --git a/paddle/fluid/operators/compat/silu.pbtxt b/paddle/fluid/operators/compat/silu.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..f760caaf5894a38eb18453e7d376ab5f8bc10958
--- /dev/null
+++ b/paddle/fluid/operators/compat/silu.pbtxt
@@ -0,0 +1,31 @@
+type: "silu"
+def {
+  inputs {
+    name: "X"
+  }
+  outputs {
+    name: "Out"
+  }
+}
+extra {
+  attrs {
+    name: "op_role"
+    type: INT
+  }
+  attrs {
+    name: "op_role_var"
+    type: STRINGS
+  }
+  attrs {
+    name: "op_namescope"
+    type: STRING
+
} + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 05632cff1b18f1b234f555d348328574df860b94..d5541666e9e99763447484d059fade5e72c59149 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -35,6 +35,10 @@ endif() if(WIN32) list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_fused_token_prune") + list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES + "test_preln_groupnorm_act_fuse_pass") + list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES + "test_element_groupnorm_act_fuse_pass") list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune") list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune") endif() @@ -217,6 +221,10 @@ if(WITH_GPU AND TENSORRT_FOUND) set_tests_properties(test_map_matmul_v2_to_mul_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_map_matmul_to_mul_pass PROPERTIES TIMEOUT 120) + set_tests_properties(test_element_groupnorm_act_fuse_pass + PROPERTIES TIMEOUT 120) + set_tests_properties(test_preln_groupnorm_act_fuse_pass PROPERTIES TIMEOUT + 120) endif() endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_element_groupnorm_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_element_groupnorm_act_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..4763c59620549b919341a6deb632dba8f3659660 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_element_groupnorm_act_fuse_pass.py @@ -0,0 +1,173 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
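+
+# This test builds an elementwise_add -> group_norm -> silu subgraph and
+# checks that elementwise_groupnorm_act_pass fuses it into a single
+# skip_groupnorm_act op when the model runs under TensorRT with fp16
+# dynamic shape.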
+ +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + +import paddle.inference as paddle_infer + + +class TestElementGNActPass(PassAutoScanTest): + # + # | | | | + # other_op1 other_op2 other_op1 other_op2 + # | | fuse \ / + # elementwise_add -> skip_groupnorm_act + # | | + # groupnorm + # | + # silu + + def sample_predictor_configs(self, program_config): + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Half, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input_data_x": [1, 160, 1, 1], + "input_data_y": [1, 160, 1, 1], + }, + { + "input_data_x": [4, 1280, 64, 64], + "input_data_y": [4, 1280, 1, 1], + }, + { + "input_data_x": [1, 320, 1, 1], + "input_data_y": [1, 320, 1, 1], + }, + ) + yield config, ['skip_groupnorm_act'], (3e-3, 1e-3) + + def sample_program_config(self, draw): + axis = draw(st.sampled_from([0, -1])) + epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001)) + batch_size = draw(st.integers(min_value=1, max_value=4)) + + groups = draw(st.sampled_from([4, 8, 16, 32])) + hw = draw(st.sampled_from([1, 8, 16, 32])) + channel = draw(st.sampled_from([320, 1280])) + + def generate_input_x(attrs): + return np.random.random( + [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]] + ).astype(np.float32) + + def generate_input_y(attrs): + return np.random.random( + [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]] + ).astype(np.float32) + + def generate_weight(attrs): + return np.random.random(attrs[1]['input_dim_x'][0]).astype( + np.float32 + ) + + attrs = [ + { + 'axis': axis, + 'epsilon': epsilon, + 'groups': groups, + }, + { + 'batch_size': batch_size, + 'input_dim_x': [channel, hw, hw], + 'input_dim_y': [channel, 1, 1], + }, + ] + + elementwise_add_op = OpConfig( + type="elementwise_add", + inputs={"X": ["input_data_x"], "Y": ["input_data_y"]}, + outputs={"Out": ["ele_out"]}, + attrs={"axis": attrs[0]['axis']}, + ) + group_norm_op = OpConfig( + type="group_norm", + inputs={ + "X": ["ele_out"], + "Bias": ["group_norm_bias"], + "Scale": ["group_norm_scale"], + }, + outputs={ + "Y": ["group_norm_output1"], + "Mean": ["group_norm_output2"], + "Variance": ["group_norm_output3"], + }, + attrs={ + "data_layout": "NCHW", + "groups": attrs[0]["groups"], + "epsilon": attrs[0]["epsilon"], + }, + ) + silu_op = OpConfig( + type="silu", + inputs={ + "X": ["group_norm_output1"], + }, + outputs={ + "Out": ["silu_output"], + }, + ) + + program_config = ProgramConfig( + ops=[ + elementwise_add_op, + group_norm_op, + silu_op, + ], + weights={ + "group_norm_bias": TensorConfig( + data_gen=partial(generate_weight, attrs) + ), + "group_norm_scale": TensorConfig( + data_gen=partial(generate_weight, attrs) + ), + }, + inputs={ + "input_data_x": TensorConfig( + data_gen=partial(generate_input_x, attrs) + ), + "input_data_y": TensorConfig( + data_gen=partial(generate_input_y, attrs) + ), + }, + outputs=["silu_output"], + ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["elementwise_groupnorm_act_pass"], + max_duration=250, + min_success_num=50, + ) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5633d2e7fe4af41bc20d36d96d52423eb6cd9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_preln_groupnorm_act_fuse_pass.py @@ -0,0 +1,173 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + +import paddle.inference as paddle_infer + + +class TestElementGNActPass(PassAutoScanTest): + # + # | | | | + # other_op1 other_op2 other_op1 other_op2 + # | | fuse \ / + # elementwise_add -> preln_groupnorm_act + # | | | | + # other_op3 groupnorm other_op3 + # | + # silu + + def sample_predictor_configs(self, program_config): + # trt dynamic_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=1, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Half, + use_static=False, + use_calib_mode=False, + ) + config.set_trt_dynamic_shape_info( + { + "input_data_x": [1, 160, 1, 1], + "input_data_y": [1, 160, 1, 1], + }, + { + "input_data_x": [4, 1280, 64, 64], + "input_data_y": [4, 1280, 64, 64], + }, + { + "input_data_x": [1, 320, 32, 32], + "input_data_y": [1, 320, 32, 32], + }, + ) + yield config, ['preln_groupnorm_act'], (3e-3, 1e-3) + + def sample_program_config(self, draw): + axis = draw(st.sampled_from([0, -1])) + epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001)) + batch_size = draw(st.integers(min_value=1, max_value=4)) + + groups = draw(st.sampled_from([4, 8, 16, 32])) + hw = draw(st.sampled_from([1, 8, 16, 32])) + channel = draw(st.sampled_from([320, 1280])) + + def generate_input_x(attrs): + return np.random.random( + [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]] + ).astype(np.float32) + + def generate_input_y(attrs): + return np.random.random( + [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]] + ).astype(np.float32) + + def generate_weight(attrs): + return np.random.random(attrs[1]['input_dim_x'][0]).astype( + np.float32 + ) + + attrs = [ + { + 'axis': axis, + 'epsilon': epsilon, + 'groups': groups, + }, + { + 'batch_size': batch_size, + 'input_dim_x': [channel, hw, hw], + 'input_dim_y': [channel, hw, hw], + }, + ] + + elementwise_add_op = OpConfig( + type="elementwise_add", + inputs={"X": ["input_data_x"], "Y": ["input_data_y"]}, + outputs={"Out": ["ele_out"]}, + attrs={"axis": attrs[0]['axis']}, + ) + group_norm_op = OpConfig( + type="group_norm", + inputs={ + "X": ["ele_out"], + "Bias": ["group_norm_bias"], + "Scale": ["group_norm_scale"], + }, + outputs={ + "Y": ["group_norm_output1"], + "Mean": ["group_norm_output2"], + "Variance": 
["group_norm_output3"], + }, + attrs={ + "data_layout": "NCHW", + "groups": attrs[0]["groups"], + "epsilon": attrs[0]["epsilon"], + }, + ) + silu_op = OpConfig( + type="silu", + inputs={ + "X": ["group_norm_output1"], + }, + outputs={ + "Out": ["silu_output"], + }, + ) + + program_config = ProgramConfig( + ops=[ + elementwise_add_op, + group_norm_op, + silu_op, + ], + weights={ + "group_norm_bias": TensorConfig( + data_gen=partial(generate_weight, attrs) + ), + "group_norm_scale": TensorConfig( + data_gen=partial(generate_weight, attrs) + ), + }, + inputs={ + "input_data_x": TensorConfig( + data_gen=partial(generate_input_x, attrs) + ), + "input_data_y": TensorConfig( + data_gen=partial(generate_input_y, attrs) + ), + }, + outputs=["ele_out", "silu_output"], + ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["preln_elementwise_groupnorm_act_pass"], + max_duration=250, + min_success_num=50, + ) + + +if __name__ == "__main__": + unittest.main()