Unverified commit 111075a3, authored by wenbin, committed by GitHub

gn_silu (#49928)

* gn_silu

* add ut

* set TIMEOUT

* correct comments

* comments

* disable windows ut

* rename parameter
Parent b0ee022b
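The patch fuses a group_norm op followed by silu into a single group_norm op carrying a with_silu attribute, so the TensorRT group-norm plugin can apply the activation inside the same kernel. A minimal NumPy sketch of the semantics the fused op must preserve (shapes and the helper name are illustrative, not Paddle code):

import numpy as np

def group_norm_silu(x, scale, bias, groups, epsilon=1e-5):
    # Reference for the fused op: group_norm over an NCHW tensor, then silu.
    n, c, h, w = x.shape
    g = x.reshape(n, groups, c // groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    y = y * scale.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)
    return y / (1.0 + np.exp(-y))  # silu(y) = y * sigmoid(y)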
@@ -144,6 +144,7 @@ if(WITH_TENSORRT)
   pass_library(trt_support_nhwc_pass inference)
   pass_library(elementwise_groupnorm_act_pass inference)
   pass_library(preln_elementwise_groupnorm_act_pass inference)
+  pass_library(groupnorm_act_pass inference)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
 endif()
...
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/groupnorm_act_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct GroupNormAct : public PatternBase {
GroupNormAct(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "groupnorm_act") {}
void operator()(PDNode *x);
// declare operator node's name
PATTERN_DECL_NODE(group_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(group_norm_bias);
PATTERN_DECL_NODE(group_norm_scale);
PATTERN_DECL_NODE(group_norm_out);
PATTERN_DECL_NODE(act);
PATTERN_DECL_NODE(act_out);
};
void GroupNormAct::operator()(PDNode *x) {
// Create nodes for group_norm op.
auto *group_norm =
pattern->NewNode(group_norm_repr())->assert_is_op("group_norm");
auto *group_norm_bias_var = pattern->NewNode(group_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("group_norm", "Bias");
auto *group_norm_scale_var = pattern->NewNode(group_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("group_norm", "Scale");
auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
->AsOutput()
->assert_is_op_output("group_norm", "Y")
->assert_is_op_input("silu", "X");
// Add links for group_norm op.
group_norm->LinksFrom({x, group_norm_bias_var, group_norm_scale_var})
.LinksTo({group_norm_out_var});
auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
auto *act_out = pattern->NewNode(act_out_repr())
->AsOutput()
->assert_is_op_output("silu", "Out");
act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
}
} // namespace patterns
int GroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("groupnorm_silu_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
PDNode *x = nullptr;
x = gpd.mutable_pattern()
->NewNode("groupnorm_act_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("group_norm", "X");
patterns::GroupNormAct fused_pattern(gpd.mutable_pattern(),
"groupnorm_act_fuse");
fused_pattern(x);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle groupnorm act fuse";
GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_bias, group_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
group_norm_scale, group_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "groupnorm act pass in op compat failed.";
return;
}
std::unordered_set<const Node *> del_node_set;
// Create the fused group_norm (with silu) op node
OpDesc new_desc(*group_norm->Op());
new_desc.SetAttr("with_silu", true);
new_desc.SetOutput("Y", {act_out->Name()});
new_desc.Flush();
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(group_norm);
del_node_set.insert(group_norm_out);
del_node_set.insert(act);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(group_norm_scale, fused_node);
IR_NODE_LINK_TO(group_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, act_out);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void GroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("groupnorm_act_fuse_pass", graph);
int found_subgraph_count = ApplyGNSiluPattern(graph);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(groupnorm_act_pass, paddle::framework::ir::GroupNormActFusePass);
REGISTER_PASS_CAPABILITY(groupnorm_act_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("silu", 0)
.EQ("group_norm", 0));
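After the rewrite, the handler reuses the original group_norm OpDesc, sets with_silu = true, and redirects Y to the former silu output, so downstream consumers still see an ordinary group_norm. Sketched with the OpConfig helper used by the unit test at the bottom of this commit (groups and epsilon values are placeholders; the pass does not literally build this object):

from program_config import OpConfig

# What the graph contains after fusion (sketch): group_norm + silu have
# collapsed into one group_norm whose Y is the former silu output.
fused_group_norm = OpConfig(
    type="group_norm",
    inputs={
        "X": ["input_data"],
        "Scale": ["group_norm_scale"],
        "Bias": ["group_norm_bias"],
    },
    outputs={
        "Y": ["silu_output"],
        "Mean": ["group_norm_output2"],
        "Variance": ["group_norm_output3"],
    },
    attrs={
        "data_layout": "NCHW",
        "groups": 32,        # placeholder value
        "epsilon": 1e-5,     # placeholder value
        "with_silu": True,   # set by the pass
    },
)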
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
// Fuse the group_norm + silu subgraph into a single group_norm op
// (the fused op carries the attribute with_silu = true):
//
//        |                          |
//    group_norm             group_norm(with_silu)
//        |            ->            |
//      silu
//        |
class Graph;
class GroupNormActFusePass : public FusePassBase {
public:
GroupNormActFusePass() {
AddOpCompat(OpCompat("group_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(1.0f)
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW"})
.End();
AddOpCompat(OpCompat("silu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
virtual ~GroupNormActFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyGNSiluPattern(ir::Graph* graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -136,6 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
 #else
       "elementwise_groupnorm_act_pass",        //
       "preln_elementwise_groupnorm_act_pass",  //
+      "groupnorm_act_pass",                    //
 #endif
       "tensorrt_subgraph_pass",  //
       "conv_bn_fuse_pass",       //
...
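On builds where it is compiled into kTRTSubgraphPasses (the #else branch above), the fusion runs automatically whenever the TensorRT engine is enabled; nothing has to be switched on by hand. A hedged usage sketch (model file names are placeholders, and delete_pass is shown only as the assumed opt-out on the inference Config):

import paddle.inference as paddle_infer

# Hypothetical model files; any NCHW model containing group_norm + silu pairs applies.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
# The fusion is part of the default TRT subgraph passes; to opt out (assumed API):
# config.delete_pass("groupnorm_act_pass")
predictor = paddle_infer.create_predictor(config)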
@@ -46,6 +46,11 @@ class GroupNormOpConverter : public OpConverter {
     std::string scale_name = op_desc.Input("Scale").front();
     std::string bias_name = op_desc.Input("Bias").front();
+    bool with_silu = false;
+    if (op_desc.HasAttr("with_silu")) {
+      with_silu = PADDLE_GET_CONST(bool, op_desc.GetAttr("with_silu"));
+    }
+
     // get the persistable var's data
     auto GetWeight = [&](const std::string& var_name,
                          framework::DDim* dims) -> TensorRTEngine::Weight {
@@ -77,6 +82,7 @@ class GroupNormOpConverter : public OpConverter {
         groups,
         mean_shape,
         variance_shape,
+        with_silu,
         with_fp16);
     nvinfer1::ILayer* groupnorm_layer =
         engine_->AddDynamicPlugin(&input_itensor, 1, plugin);
...
@@ -49,8 +49,8 @@ struct GroupNormNHWCParams {
   int32_t c;
   // The number of groups.
   int32_t groups;
-  // Do we apply the Swish activation function?
-  bool withSwish;
+  // Do we apply the Silu activation function?
+  bool withSilu;
   // Precomputed values and parameters to control the execution of the kernels.
...
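The withSwish -> withSilu rename is purely a naming fix: silu(x) = x * sigmoid(x) is swish with beta = 1, which is exactly what the scale kernels compute. A quick sketch of that equivalence:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swish(x, beta=1.0):
    return x * sigmoid(beta * x)

def silu(x):
    # Matches the kernel: f2.x = f2.x * sigmoid(f2.x)
    return x * sigmoid(x)

x = np.linspace(-4.0, 4.0, 9)
assert np.allclose(silu(x), swish(x, beta=1.0))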
@@ -247,8 +247,8 @@ __global__ void groupNormNHWCScaleKernel(const GroupNormNHWCParams params) {
     f2.x = gammaF2.x * f2.x + betaF2.x;
     f2.y = gammaF2.y * f2.y + betaF2.y;
-    // Apply Swish if needed.
-    if (params.withSwish) {
+    // Apply Silu if needed.
+    if (params.withSilu) {
       f2.x = f2.x * sigmoid(f2.x);
       f2.y = f2.y * sigmoid(f2.y);
     }
@@ -457,7 +457,7 @@ bool GroupNormPluginDynamic::supportsFormatCombination(
   if (pos == 0) {
     if (with_fp16_) {
       return ((in.type == nvinfer1::DataType::kHALF) &&
-              (in.format == nvinfer1::PluginFormat::kLINEAR ||
+              ((!with_silu_ && in.format == nvinfer1::PluginFormat::kLINEAR) ||
               in.format == nvinfer1::PluginFormat::kHWC8));
     } else {
       return (in.type == nvinfer1::DataType::kFLOAT) &&
@@ -624,7 +624,7 @@ int GroupNormPluginDynamic::enqueue(
       cPerBlock = 8;
     }
-    params_.withSwish = false;
+    params_.withSilu = with_silu_;
     params_.dst = static_cast<half *>(outputs[0]);
     params_.srcX = static_cast<half const *>(inputs[0]);
     params_.gamma = scale_gpu_;
...
@@ -164,11 +164,13 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
                          int groups,
                          std::vector<int64_t> mean_shape,
                          std::vector<int64_t> variance_shape,
+                         bool with_silu,
                          bool with_fp16)
       : groups_(groups),
         eps_(eps),
         mean_shape_(mean_shape),
         variance_shape_(variance_shape),
+        with_silu_(with_silu),
         with_fp16_(with_fp16) {
     scale_.resize(scale_num);
     bias_.resize(bias_num);
@@ -183,6 +185,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &groups_);
     DeserializeValue(&serialData, &serialLength, &mean_shape_);
     DeserializeValue(&serialData, &serialLength, &variance_shape_);
+    DeserializeValue(&serialData, &serialLength, &with_silu_);
     DeserializeValue(&serialData, &serialLength, &with_fp16_);
   }
   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
@@ -194,6 +197,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
         groups_,
         mean_shape_,
         variance_shape_,
+        with_silu_,
         with_fp16_);
     ptr->scale_gpu_ = scale_gpu_;
     ptr->bias_gpu_ = bias_gpu_;
@@ -210,7 +214,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     return SerializedSize(scale_) + SerializedSize(bias_) +
            SerializedSize(eps_) + SerializedSize(groups_) +
            SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
-           SerializedSize(with_fp16_);
+           SerializedSize(with_silu_) + SerializedSize(with_fp16_);
   }
   void serialize(void* buffer) const TRT_NOEXCEPT override {
     SerializeValue(&buffer, scale_);
@@ -219,6 +223,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
     SerializeValue(&buffer, groups_);
     SerializeValue(&buffer, mean_shape_);
     SerializeValue(&buffer, variance_shape_);
+    SerializeValue(&buffer, with_silu_);
     SerializeValue(&buffer, with_fp16_);
   }
   nvinfer1::DimsExprs getOutputDimensions(
@@ -277,6 +282,7 @@ class GroupNormPluginDynamic : public DynamicPluginTensorRT {
   std::vector<int64_t> mean_shape_;
   std::vector<int64_t> variance_shape_;
   GroupNormNHWCParams params_;
+  bool with_silu_;
   bool with_fp16_;
 };
 class GroupNormPluginDynamicCreator : public TensorRTPluginCreator {
...
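getSerializationSize(), serialize() and the deserializing constructor must stay in lockstep: with_silu_ is sized, written and read at the same position, between variance_shape_ and with_fp16_. A stand-alone sketch of that order dependence (Python's struct stands in for SerializeValue/DeserializeValue; it is not Paddle's implementation):

import struct

# Illustrative field order only: eps, groups, with_silu, with_fp16.
FMT = "<fi??"  # float32, int32, bool, bool

def serialize(eps, groups, with_silu, with_fp16):
    return struct.pack(FMT, eps, groups, with_silu, with_fp16)

def deserialize(buf):
    # Reading in a different order, or skipping a field, corrupts every field
    # after it -- hence the matching edits to getSerializationSize(),
    # serialize() and the deserializing constructor.
    return struct.unpack(FMT, buf)

eps, groups, with_silu, with_fp16 = deserialize(serialize(1e-5, 32, True, True))
assert (groups, with_silu, with_fp16) == (32, True, True)
assert abs(eps - 1e-5) < 1e-10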
@@ -330,8 +330,8 @@ __global__ void prelnGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
     f2.x = gammaF2.x * f2.x + betaF2.x;
     f2.y = gammaF2.y * f2.y + betaF2.y;
-    // Apply Swish if needed.
-    if (params.withSwish) {
+    // Apply Silu if needed.
+    if (params.withSilu) {
       f2.x = f2.x * sigmoid(f2.x);
       f2.y = f2.y * sigmoid(f2.y);
     }
@@ -431,7 +431,7 @@ int PrelnGroupnormActPluginDynamic::enqueue(
     if (cPerBlock > input_desc[0].dims.d[1]) {
       cPerBlock = 8;
     }
-    params_.withSwish = with_silu_;
+    params_.withSilu = with_silu_;
     params_.dst = static_cast<half *>(outputs[1]);
     params_.eleOut = static_cast<half *>(outputs[0]);
     params_.srcX = static_cast<half const *>(inputs[0]);
...
@@ -340,8 +340,8 @@ __global__ void skipGroupNormNHWCScaleKernel(GroupNormNHWCParams params) {
     f2.x = gammaF2.x * f2.x + betaF2.x;
     f2.y = gammaF2.y * f2.y + betaF2.y;
-    // Apply Swish if needed.
-    if (params.withSwish) {
+    // Apply Silu if needed.
+    if (params.withSilu) {
       f2.x = f2.x * sigmoid(f2.x);
       f2.y = f2.y * sigmoid(f2.y);
     }
@@ -439,7 +439,7 @@ int SkipGroupnormActPluginDynamic::enqueue(
     if (cPerBlock > input_desc[0].dims.d[1]) {
       cPerBlock = 8;
     }
-    params_.withSwish = true;
+    params_.withSilu = true;
     params_.dst = static_cast<half *>(outputs[0]);
     params_.srcX = static_cast<half const *>(inputs[0]);
     params_.srcY = static_cast<half const *>(inputs[1]);
...
@@ -39,6 +39,7 @@ if(WIN32)
                        "test_preln_groupnorm_act_fuse_pass")
   list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
                        "test_element_groupnorm_act_fuse_pass")
+  list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_groupnorm_act_pass_fuse_pass")
   list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
   list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
 endif()
@@ -225,6 +226,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
                            PROPERTIES TIMEOUT 120)
     set_tests_properties(test_preln_groupnorm_act_fuse_pass PROPERTIES TIMEOUT
                                                             120)
+    set_tests_properties(test_groupnorm_act_pass_fuse_pass PROPERTIES TIMEOUT
+                                                           120)
   endif()
 endif()
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestElementGNActPass(PassAutoScanTest):
    #
    #       |            fuse            |
    #   groupnorm         ->      groupnorm(with_silu)
    #       |                           |
    #     silu
    #       |
    #
    def sample_predictor_configs(self, program_config):
        # trt dynamic_shape
        config = self.create_trt_inference_config()
        config.enable_tensorrt_engine(
            max_batch_size=1,
            workspace_size=102400,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Half,
            use_static=False,
            use_calib_mode=False,
        )
        config.set_trt_dynamic_shape_info(
            {
                "input_data": [1, 160, 1, 1],
            },
            {
                "input_data": [4, 1280, 64, 64],
            },
            {
                "input_data": [1, 320, 32, 32],
            },
        )
        yield config, ['group_norm'], (3e-3, 1e-3)
    def sample_program_config(self, draw):
        axis = draw(st.sampled_from([0, -1]))
        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
        batch_size = draw(st.integers(min_value=1, max_value=4))
        groups = draw(st.sampled_from([4, 8, 16, 32]))
        hw = draw(st.sampled_from([1, 8, 16, 32]))
        channel = draw(st.sampled_from([320, 1280]))

        def generate_input(attrs):
            return np.random.random(
                [attrs[1]["batch_size"], *attrs[1]["input_dim"]]
            ).astype(np.float32)

        def generate_weight(attrs):
            return np.random.random(attrs[1]['input_dim'][0]).astype(np.float32)

        attrs = [
            {
                'epsilon': epsilon,
                'groups': groups,
            },
            {
                'batch_size': batch_size,
                'input_dim': [channel, hw, hw],
            },
        ]
        group_norm_op = OpConfig(
            type="group_norm",
            inputs={
                "X": ["input_data"],
                "Bias": ["group_norm_bias"],
                "Scale": ["group_norm_scale"],
            },
            outputs={
                "Y": ["group_norm_output1"],
                "Mean": ["group_norm_output2"],
                "Variance": ["group_norm_output3"],
            },
            attrs={
                "data_layout": "NCHW",
                "groups": attrs[0]["groups"],
                "epsilon": attrs[0]["epsilon"],
            },
        )
        silu_op = OpConfig(
            type="silu",
            inputs={
                "X": ["group_norm_output1"],
            },
            outputs={
                "Out": ["silu_output"],
            },
        )
        program_config = ProgramConfig(
            ops=[
                group_norm_op,
                silu_op,
            ],
            weights={
                "group_norm_bias": TensorConfig(
                    data_gen=partial(generate_weight, attrs)
                ),
                "group_norm_scale": TensorConfig(
                    data_gen=partial(generate_weight, attrs)
                ),
            },
            inputs={
                "input_data": TensorConfig(
                    data_gen=partial(generate_input, attrs)
                ),
            },
            outputs=["silu_output"],
        )
        return program_config
    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=50,
            passes=["groupnorm_act_pass"],
            max_duration=250,
            min_success_num=50,
        )


if __name__ == "__main__":
    unittest.main()