Unverified · Commit adcb0039 · Authored by W wenbin, committed by GitHub

more preln_gn patterns (#49728)

* compile fix

* fix compile

* compile fix

* add more preln
Parent: a015f815
@@ -35,7 +35,7 @@ struct PrelnGroupNormAct : public PatternBase {
   PrelnGroupNormAct(PDPattern *pattern, const std::string &name_scope)
       : PatternBase(pattern, name_scope, "preln_groupnorm_act") {}

-  void operator()(PDNode *x, PDNode *y);
+  void operator()(PDNode *x, PDNode *y, bool with_act);

   // declare operator node's name
   PATTERN_DECL_NODE(elementwise);
   PATTERN_DECL_NODE(group_norm);
@@ -49,7 +49,7 @@ struct PrelnGroupNormAct : public PatternBase {
   PATTERN_DECL_NODE(act_out);
 };

-void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
+void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y, bool with_act) {
   auto *elementwise =
       pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
@@ -74,26 +74,28 @@ void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
   auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
                                  ->AsOutput()
-                                 ->assert_is_op_output("group_norm", "Y")
-                                 ->assert_is_op_input("silu", "X");
+                                 ->assert_is_op_output("group_norm", "Y");

   // Add links for group_norm op.
   group_norm
       ->LinksFrom(
           {elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
       .LinksTo({group_norm_out_var});

-  auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
-  auto *act_out = pattern->NewNode(act_out_repr())
-                      ->AsOutput()
-                      ->assert_is_op_output("silu", "Out");
-
-  act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+  if (with_act) {
+    group_norm_out_var->assert_is_op_input("silu", "X");
+    auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
+    auto *act_out = pattern->NewNode(act_out_repr())
+                        ->AsOutput()
+                        ->assert_is_op_output("silu", "Out");
+
+    act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+  }
 }

 }  // namespace patterns

-int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
+int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph,
+                                                 bool with_act) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   FusePassBase::Init("preln_groupnorm_silu_fuse", graph);
@@ -118,7 +120,7 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
   patterns::PrelnGroupNormAct fused_pattern(gpd.mutable_pattern(),
                                             "preln_groupnorm_act_fuse");
-  fused_pattern(x, y);
+  fused_pattern(x, y, with_act);

   auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                      Graph *graph) {
@@ -129,6 +131,9 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     VLOG(4) << "handle preln groupnorm act fuse";

+    Node *act = nullptr;
+    Node *act_out = nullptr;
+
     GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
@@ -136,8 +141,12 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     GET_IR_NODE_FROM_SUBGRAPH(
         group_norm_scale, group_norm_scale, fused_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
+    if (with_act) {
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_act, act, fused_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_act_out, act_out, fused_pattern);
+      act = tmp_act;
+      act_out = tmp_act_out;
+    }

     if (!IsCompat(subgraph, graph)) {
       LOG(WARNING) << "preln groupnorm act pass in op compat failed.";
@@ -150,8 +159,13 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     new_desc.SetType("preln_groupnorm_act");
     new_desc.SetInput("X", {subgraph.at(x)->Name()});
     new_desc.SetInput("Y", {subgraph.at(y)->Name()});
+    new_desc.SetAttr("with_silu", with_act);
     new_desc.SetOutput("Out_0", {elementwise_out->Name()});
-    new_desc.SetOutput("Out_1", {act_out->Name()});
+    if (with_act) {
+      new_desc.SetOutput("Out_1", {act_out->Name()});
+    } else {
+      new_desc.SetOutput("Out_1", {group_norm_out->Name()});
+    }
     new_desc.RemoveOutput("Y");
     new_desc.Flush();
@@ -159,15 +173,21 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
     del_node_set.insert(elementwise);
     del_node_set.insert(group_norm);
-    del_node_set.insert(group_norm_out);
-    del_node_set.insert(act);
+    if (with_act) {
+      del_node_set.insert(act);
+      del_node_set.insert(group_norm_out);
+    }
     GraphSafeRemoveNodes(graph, del_node_set);

     IR_NODE_LINK_TO(subgraph.at(x), fused_node);
     IR_NODE_LINK_TO(subgraph.at(y), fused_node);
     IR_NODE_LINK_TO(group_norm_scale, fused_node);
     IR_NODE_LINK_TO(group_norm_bias, fused_node);
-    IR_NODE_LINK_TO(fused_node, act_out);
+    if (with_act) {
+      IR_NODE_LINK_TO(fused_node, act_out);
+    } else {
+      IR_NODE_LINK_TO(fused_node, group_norm_out);
+    }
     IR_NODE_LINK_TO(fused_node, elementwise_out);

     found_subgraph_count++;
   };
@@ -178,7 +198,8 @@ int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {

 void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
   FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph);
-  int found_subgraph_count = ApplyGNSiluPattern(graph);
+  int found_subgraph_count = ApplyAddGNPattern(graph, true);
+  found_subgraph_count += ApplyAddGNPattern(graph, false);
   AddStatis(found_subgraph_count);
 }
......
@@ -25,7 +25,7 @@ namespace ir {
 //      |          |                ->  preln_gn_act
 //  other op   group_norm                 |     |
 //                 |                   other op
-//               silu
+//               silu(optional)
 //                 |

 class Graph;
@@ -88,7 +88,7 @@ class PrelnGroupNormActFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

-  int ApplyGNSiluPattern(ir::Graph* graph) const;
+  int ApplyAddGNPattern(ir::Graph* graph, bool with_act) const;
 };

 }  // namespace ir
......
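Editor's note: to make the fused semantics concrete, here is a minimal eager-mode sketch of the two subgraph shapes the reworked pass now matches — elementwise_add followed by group_norm, with the trailing silu now optional, mirroring the diagram in the pass header above. The tensor shapes, group count, and variable names are illustrative assumptions, not values taken from this patch; only the op combination and the Out_0/Out_1 roles come from the diff.

# Illustrative sketch only (not part of the commit): eager-mode equivalent of
# the subgraph that preln_elementwise_groupnorm_act_pass fuses.
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 320, 32, 32])   # assumed residual-branch input (X)
y = paddle.rand([1, 320, 32, 32])   # assumed main-branch input (Y)
gn = paddle.nn.GroupNorm(num_groups=32, num_channels=320, epsilon=1e-5)

out_0 = paddle.add(x, y)            # elementwise_add -> fused op output Out_0
gn_out = gn(out_0)                  # group_norm over the sum
with_silu = True                    # the new pattern also matches with_silu=False
out_1 = F.silu(gn_out) if with_silu else gn_out   # fused op output Out_1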
@@ -45,6 +45,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
     int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
     float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
+    bool with_silu = PADDLE_GET_CONST(bool, op_desc.GetAttr("with_silu"));

     std::string scale_name = op_desc.Input("Scale").front();
     std::string bias_name = op_desc.Input("Bias").front();
@@ -75,6 +76,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
             bias_weights.get().count,
             epsilon,
             groups,
+            with_silu,
             with_fp16);
     nvinfer1::ILayer* groupnorm_layer =
         engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
......
@@ -431,7 +431,7 @@ int PrelnGroupnormActPluginDynamic::enqueue(
     if (cPerBlock > input_desc[0].dims.d[1]) {
       cPerBlock = 8;
     }
-    params_.withSwish = true;
+    params_.withSwish = with_silu_;
     params_.dst = static_cast<half *>(outputs[1]);
     params_.eleOut = static_cast<half *>(outputs[0]);
     params_.srcX = static_cast<half const *>(inputs[0]);
......
@@ -36,6 +36,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
                                 const int bias_num,
                                 float eps,
                                 int groups,
+                                bool with_silu,
                                 bool with_fp16,
                                 std::shared_ptr<void> scale_gpu = nullptr,
                                 std::shared_ptr<void> bias_gpu = nullptr)
@@ -43,6 +44,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
         bias_gpu_(bias_gpu),
         groups_(groups),
         eps_(eps),
+        with_silu_(with_silu),
         with_fp16_(with_fp16) {
     scale_.resize(scale_num);
     bias_.resize(bias_num);
@@ -69,6 +71,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &bias_);
     DeserializeValue(&serialData, &serialLength, &eps_);
     DeserializeValue(&serialData, &serialLength, &groups_);
+    DeserializeValue(&serialData, &serialLength, &with_silu_);
     DeserializeValue(&serialData, &serialLength, &with_fp16_);
     {
@@ -97,6 +100,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
         bias_.size(),
         eps_,
         groups_,
+        with_silu_,
         with_fp16_,
         scale_gpu_,
         bias_gpu_);
@@ -112,13 +116,14 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     return SerializedSize(scale_) + SerializedSize(bias_) +
            SerializedSize(eps_) + SerializedSize(groups_) +
-           SerializedSize(with_fp16_);
+           SerializedSize(with_silu_) + SerializedSize(with_fp16_);
   }

   void serialize(void* buffer) const TRT_NOEXCEPT override {
     SerializeValue(&buffer, scale_);
     SerializeValue(&buffer, bias_);
     SerializeValue(&buffer, eps_);
     SerializeValue(&buffer, groups_);
+    SerializeValue(&buffer, with_silu_);
     SerializeValue(&buffer, with_fp16_);
   }

   nvinfer1::DimsExprs getOutputDimensions(
@@ -171,6 +176,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
   GroupNormNHWCParams params_;
   int groups_;
   float eps_;
+  bool with_silu_;
   bool with_fp16_;
 };
......
@@ -169,5 +169,141 @@ class TestElementGNActPass(PassAutoScanTest):
         )


+class TestElementGNNoActPass(PassAutoScanTest):
+    #
+    #      |           |                              |          |
+    #  other_op1   other_op2                      other_op1  other_op2
+    #      |           |             fuse              \        /
+    #      elementwise_add            ->          preln_groupnorm_act
+    #        |         |                               |        |
+    #  other_op3   groupnorm                       other_op3
+    #                  |
+    #
+    def sample_predictor_configs(self, program_config):
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=1,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Half,
+            use_static=False,
+            use_calib_mode=False,
+        )
+        config.set_trt_dynamic_shape_info(
+            {
+                "input_data_x": [1, 160, 1, 1],
+                "input_data_y": [1, 160, 1, 1],
+            },
+            {
+                "input_data_x": [4, 1280, 64, 64],
+                "input_data_y": [4, 1280, 64, 64],
+            },
+            {
+                "input_data_x": [1, 320, 32, 32],
+                "input_data_y": [1, 320, 32, 32],
+            },
+        )
+        yield config, ['preln_groupnorm_act'], (3e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([0, -1]))
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+        groups = draw(st.sampled_from([4, 8, 16, 32]))
+        hw = draw(st.sampled_from([1, 8, 16, 32]))
+        channel = draw(st.sampled_from([320, 1280]))
+
+        def generate_input_x(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
+            ).astype(np.float32)
+
+        def generate_input_y(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
+            ).astype(np.float32)
+
+        def generate_weight(attrs):
+            return np.random.random(attrs[1]['input_dim_x'][0]).astype(
+                np.float32
+            )
+
+        attrs = [
+            {
+                'axis': axis,
+                'epsilon': epsilon,
+                'groups': groups,
+            },
+            {
+                'batch_size': batch_size,
+                'input_dim_x': [channel, hw, hw],
+                'input_dim_y': [channel, hw, hw],
+            },
+        ]
+
+        elementwise_add_op = OpConfig(
+            type="elementwise_add",
+            inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
+            outputs={"Out": ["ele_out"]},
+            attrs={"axis": attrs[0]['axis']},
+        )
+
+        group_norm_op = OpConfig(
+            type="group_norm",
+            inputs={
+                "X": ["ele_out"],
+                "Bias": ["group_norm_bias"],
+                "Scale": ["group_norm_scale"],
+            },
+            outputs={
+                "Y": ["group_norm_output1"],
+                "Mean": ["group_norm_output2"],
+                "Variance": ["group_norm_output3"],
+            },
+            attrs={
+                "data_layout": "NCHW",
+                "groups": attrs[0]["groups"],
+                "epsilon": attrs[0]["epsilon"],
+            },
+        )
+
+        program_config = ProgramConfig(
+            ops=[
+                elementwise_add_op,
+                group_norm_op,
+            ],
+            weights={
+                "group_norm_bias": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+                "group_norm_scale": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+            },
+            inputs={
+                "input_data_x": TensorConfig(
+                    data_gen=partial(generate_input_x, attrs)
+                ),
+                "input_data_y": TensorConfig(
+                    data_gen=partial(generate_input_y, attrs)
+                ),
+            },
+            outputs=["ele_out", "group_norm_output1"],
+        )
+
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=["preln_elementwise_groupnorm_act_pass"],
+            max_duration=250,
+            min_success_num=50,
+        )
+
+
 if __name__ == "__main__":
     unittest.main()