Unverified commit adcb0039, authored by wenbin, committed by GitHub

more preln_gn patterns (#49728)

* compile fix

* fix compile

* compile fix

* add more preln
Parent a015f815
@@ -35,7 +35,7 @@ struct PrelnGroupNormAct : public PatternBase {
PrelnGroupNormAct(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_groupnorm_act") {}
-  void operator()(PDNode *x, PDNode *y);
+  void operator()(PDNode *x, PDNode *y, bool with_act);
// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(group_norm);
@@ -49,7 +49,7 @@ struct PrelnGroupNormAct : public PatternBase {
PATTERN_DECL_NODE(act_out);
};
-void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
+void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y, bool with_act) {
auto *elementwise =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
@@ -74,26 +74,28 @@ void PrelnGroupNormAct::operator()(PDNode *x, PDNode *y) {
auto *group_norm_out_var = pattern->NewNode(group_norm_out_repr())
->AsOutput()
-    ->assert_is_op_output("group_norm", "Y")
-    ->assert_is_op_input("silu", "X");
+    ->assert_is_op_output("group_norm", "Y");
// Add links for group_norm op.
group_norm
->LinksFrom(
{elementwise_out_var, group_norm_bias_var, group_norm_scale_var})
.LinksTo({group_norm_out_var});
-auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
-auto *act_out = pattern->NewNode(act_out_repr())
-    ->AsOutput()
-    ->assert_is_op_output("silu", "Out");
-act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+if (with_act) {
+  group_norm_out_var->assert_is_op_input("silu", "X");
+  auto *act = pattern->NewNode(act_repr())->assert_is_op("silu");
+  auto *act_out = pattern->NewNode(act_out_repr())
+      ->AsOutput()
+      ->assert_is_op_output("silu", "Out");
+  act->LinksFrom({group_norm_out_var}).LinksTo({act_out});
+}
}
} // namespace patterns
-int PrelnGroupNormActFusePass::ApplyGNSiluPattern(ir::Graph *graph) const {
+int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph,
+                                                 bool with_act) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_groupnorm_silu_fuse", graph);
@@ -118,7 +120,7 @@ int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph) const {
patterns::PrelnGroupNormAct fused_pattern(gpd.mutable_pattern(),
"preln_groupnorm_act_fuse");
-fused_pattern(x, y);
+fused_pattern(x, y, with_act);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
@@ -129,6 +131,9 @@ int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph) const {
VLOG(4) << "handle preln groupnorm act fuse";
+Node *act = nullptr;
+Node *act_out = nullptr;
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm, group_norm, fused_pattern);
@@ -136,8 +141,12 @@ int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph) const {
GET_IR_NODE_FROM_SUBGRAPH(
group_norm_scale, group_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(group_norm_out, group_norm_out, fused_pattern);
-GET_IR_NODE_FROM_SUBGRAPH(act, act, fused_pattern);
-GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_pattern);
+if (with_act) {
+  GET_IR_NODE_FROM_SUBGRAPH(tmp_act, act, fused_pattern);
+  GET_IR_NODE_FROM_SUBGRAPH(tmp_act_out, act_out, fused_pattern);
+  act = tmp_act;
+  act_out = tmp_act_out;
+}
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln groupnorm act pass in op compat failed.";
@@ -150,8 +159,13 @@ int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph) const {
new_desc.SetType("preln_groupnorm_act");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
+new_desc.SetAttr("with_silu", with_act);
new_desc.SetOutput("Out_0", {elementwise_out->Name()});
new_desc.SetOutput("Out_1", {act_out->Name()});
if (with_act) {
new_desc.SetOutput("Out_1", {act_out->Name()});
} else {
new_desc.SetOutput("Out_1", {group_norm_out->Name()});
}
new_desc.RemoveOutput("Y");
new_desc.Flush();
@@ -159,15 +173,21 @@ int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph) const {
del_node_set.insert(elementwise);
del_node_set.insert(group_norm);
-del_node_set.insert(group_norm_out);
-del_node_set.insert(act);
+if (with_act) {
+  del_node_set.insert(act);
+  del_node_set.insert(group_norm_out);
+}
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(group_norm_scale, fused_node);
IR_NODE_LINK_TO(group_norm_bias, fused_node);
-IR_NODE_LINK_TO(fused_node, act_out);
+if (with_act) {
+  IR_NODE_LINK_TO(fused_node, act_out);
+} else {
+  IR_NODE_LINK_TO(fused_node, group_norm_out);
+}
IR_NODE_LINK_TO(fused_node, elementwise_out);
found_subgraph_count++;
};
@@ -178,7 +198,8 @@ int PrelnGroupNormActFusePass::ApplyAddGNPattern(ir::Graph *graph) const {
void PrelnGroupNormActFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("preln_groupnorm_act_fuse_pass", graph);
-int found_subgraph_count = ApplyGNSiluPattern(graph);
+int found_subgraph_count = ApplyAddGNPattern(graph, true);
+found_subgraph_count += ApplyAddGNPattern(graph, false);
AddStatis(found_subgraph_count);
}
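With the changes above, the pass now also matches the bare elementwise_add + group_norm chain (with_act == false), reusing group_norm_out instead of act_out as the fused op's second output. As a rough illustration of the fused op's intended semantics, here is a minimal NumPy reference sketch inferred from this diff (illustrative only, not Paddle's actual kernel):

import numpy as np

def preln_groupnorm_act_ref(x, y, scale, bias, groups, eps, with_silu):
    # Out_0: the pre-layernorm residual sum, still consumed by other ops.
    out0 = x + y
    n, c, h, w = out0.shape
    # group_norm: normalize over each channel group plus spatial dims.
    g = out0.reshape(n, groups, c // groups, h, w)
    gn = (g - g.mean(axis=(2, 3, 4), keepdims=True)) / np.sqrt(
        g.var(axis=(2, 3, 4), keepdims=True) + eps)
    out1 = (gn.reshape(n, c, h, w) * scale.reshape(1, c, 1, 1)
            + bias.reshape(1, c, 1, 1))
    if with_silu:  # only when the silu node was part of the matched subgraph
        out1 = out1 * (1.0 / (1.0 + np.exp(-out1)))  # silu(x) = x * sigmoid(x)
    return out0, out1  # (Out_0, Out_1)

# Constant-input sanity check: group_norm of a constant tensor is ~0,
# so Out_1 collapses to the bias (zeros here).
out0, out1 = preln_groupnorm_act_ref(
    np.ones((1, 8, 2, 2), np.float32), np.zeros((1, 8, 2, 2), np.float32),
    np.ones(8, np.float32), np.zeros(8, np.float32),
    groups=4, eps=1e-5, with_silu=False)
assert np.allclose(out1, 0.0)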
@@ -25,7 +25,7 @@ namespace ir {
//     |    |           ->  preln_gn_act
// other op  group_norm        |    |
//     |                    other op
-//    silu
+//    silu(optional)
//     |
class Graph;
@@ -88,7 +88,7 @@ class PrelnGroupNormActFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
-  int ApplyGNSiluPattern(ir::Graph* graph) const;
+  int ApplyAddGNPattern(ir::Graph* graph, bool with_act) const;
};
} // namespace ir
@@ -45,6 +45,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
int groups = PADDLE_GET_CONST(int, op_desc.GetAttr("groups"));
float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
+bool with_silu = PADDLE_GET_CONST(bool, op_desc.GetAttr("with_silu"));
std::string scale_name = op_desc.Input("Scale").front();
std::string bias_name = op_desc.Input("Bias").front();
@@ -75,6 +76,7 @@ class PrelnGroupnormActOpConverter : public OpConverter {
bias_weights.get().count,
epsilon,
groups,
+with_silu,
with_fp16);
nvinfer1::ILayer* groupnorm_layer =
engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
@@ -431,7 +431,7 @@ int PrelnGroupnormActPluginDynamic::enqueue(
if (cPerBlock > input_desc[0].dims.d[1]) {
cPerBlock = 8;
}
-params_.withSwish = true;
+params_.withSwish = with_silu_;
params_.dst = static_cast<half *>(outputs[1]);
params_.eleOut = static_cast<half *>(outputs[0]);
params_.srcX = static_cast<half const *>(inputs[0]);
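A side note on naming: the kernel parameter is withSwish while the graph-level flag is with_silu. Swish is usually defined as x * sigmoid(beta * x), and with beta = 1 it is exactly silu, so routing with_silu_ into params_.withSwish only generalizes the previously hard-coded true. A quick check of that identity (illustrative):

import numpy as np

x = np.linspace(-4.0, 4.0, 9)
silu = x / (1.0 + np.exp(-x))               # silu(x) = x * sigmoid(x)
swish_beta1 = x / (1.0 + np.exp(-1.0 * x))  # swish with beta = 1
assert np.allclose(silu, swish_beta1)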
@@ -36,6 +36,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
const int bias_num,
float eps,
int groups,
+bool with_silu,
bool with_fp16,
std::shared_ptr<void> scale_gpu = nullptr,
std::shared_ptr<void> bias_gpu = nullptr)
@@ -43,6 +44,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
bias_gpu_(bias_gpu),
groups_(groups),
eps_(eps),
+with_silu_(with_silu),
with_fp16_(with_fp16) {
scale_.resize(scale_num);
bias_.resize(bias_num);
@@ -69,6 +71,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
DeserializeValue(&serialData, &serialLength, &bias_);
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &groups_);
+DeserializeValue(&serialData, &serialLength, &with_silu_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
{
@@ -97,6 +100,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
bias_.size(),
eps_,
groups_,
+with_silu_,
with_fp16_,
scale_gpu_,
bias_gpu_);
@@ -112,13 +116,14 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(scale_) + SerializedSize(bias_) +
SerializedSize(eps_) + SerializedSize(groups_) +
-       SerializedSize(with_fp16_);
+       SerializedSize(with_silu_) + SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, groups_);
+SerializeValue(&buffer, with_silu_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
@@ -171,6 +176,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT {
GroupNormNHWCParams params_;
int groups_;
float eps_;
+bool with_silu_;
bool with_fp16_;
};
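TensorRT requires getSerializationSize(), serialize(), and the deserializing constructor to agree on field order, which is why with_silu_ is threaded through all three in the same position (between groups_ and with_fp16_). A small self-contained sketch of that invariant, with Python's struct standing in for SerializeValue/DeserializeValue:

import struct

# Mirror of the plugin's trailing field order: eps_, groups_, with_silu_,
# with_fp16_ (the scale_/bias_ vectors are omitted for brevity).
FMT = "<fi??"

def serialize(eps, groups, with_silu, with_fp16):
    return struct.pack(FMT, eps, groups, with_silu, with_fp16)

def deserialize(buf):
    return struct.unpack(FMT, buf)

buf = serialize(1e-5, 32, True, True)
assert len(buf) == struct.calcsize(FMT)  # analogue of getSerializationSize()
eps, groups, with_silu, with_fp16 = deserialize(buf)
assert (groups, with_silu, with_fp16) == (32, True, True)
assert abs(eps - 1e-5) < 1e-9  # float32 round-trip is approximate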
@@ -169,5 +169,141 @@ class TestElementGNActPass(PassAutoScanTest):
)
+class TestElementGNNoActPass(PassAutoScanTest):
+    #
+    #      |     |                      |     |
+    #  other_op1 other_op2         other_op1 other_op2
+    #      |     |        fuse         \     /
+    #   elementwise_add    ->    preln_groupnorm_act
+    #      |     |                      |     |
+    #  other_op3 groupnorm         other_op3
+    #      |
+    #
+    def sample_predictor_configs(self, program_config):
+        # trt dynamic_shape
+        config = self.create_trt_inference_config()
+        config.enable_tensorrt_engine(
+            max_batch_size=1,
+            workspace_size=102400,
+            min_subgraph_size=0,
+            precision_mode=paddle_infer.PrecisionType.Half,
+            use_static=False,
+            use_calib_mode=False,
+        )
+        config.set_trt_dynamic_shape_info(
+            {
+                "input_data_x": [1, 160, 1, 1],
+                "input_data_y": [1, 160, 1, 1],
+            },
+            {
+                "input_data_x": [4, 1280, 64, 64],
+                "input_data_y": [4, 1280, 64, 64],
+            },
+            {
+                "input_data_x": [1, 320, 32, 32],
+                "input_data_y": [1, 320, 32, 32],
+            },
+        )
+        yield config, ['preln_groupnorm_act'], (3e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        axis = draw(st.sampled_from([0, -1]))
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+        groups = draw(st.sampled_from([4, 8, 16, 32]))
+        hw = draw(st.sampled_from([1, 8, 16, 32]))
+        channel = draw(st.sampled_from([320, 1280]))
+
+        def generate_input_x(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim_x"]]
+            ).astype(np.float32)
+
+        def generate_input_y(attrs):
+            return np.random.random(
+                [attrs[1]["batch_size"], *attrs[1]["input_dim_y"]]
+            ).astype(np.float32)
+
+        def generate_weight(attrs):
+            return np.random.random(attrs[1]['input_dim_x'][0]).astype(
+                np.float32
+            )
+
+        attrs = [
+            {
+                'axis': axis,
+                'epsilon': epsilon,
+                'groups': groups,
+            },
+            {
+                'batch_size': batch_size,
+                'input_dim_x': [channel, hw, hw],
+                'input_dim_y': [channel, hw, hw],
+            },
+        ]
+        elementwise_add_op = OpConfig(
+            type="elementwise_add",
+            inputs={"X": ["input_data_x"], "Y": ["input_data_y"]},
+            outputs={"Out": ["ele_out"]},
+            attrs={"axis": attrs[0]['axis']},
+        )
+        group_norm_op = OpConfig(
+            type="group_norm",
+            inputs={
+                "X": ["ele_out"],
+                "Bias": ["group_norm_bias"],
+                "Scale": ["group_norm_scale"],
+            },
+            outputs={
+                "Y": ["group_norm_output1"],
+                "Mean": ["group_norm_output2"],
+                "Variance": ["group_norm_output3"],
+            },
+            attrs={
+                "data_layout": "NCHW",
+                "groups": attrs[0]["groups"],
+                "epsilon": attrs[0]["epsilon"],
+            },
+        )
+
+        program_config = ProgramConfig(
+            ops=[
+                elementwise_add_op,
+                group_norm_op,
+            ],
+            weights={
+                "group_norm_bias": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+                "group_norm_scale": TensorConfig(
+                    data_gen=partial(generate_weight, attrs)
+                ),
+            },
+            inputs={
+                "input_data_x": TensorConfig(
+                    data_gen=partial(generate_input_x, attrs)
+                ),
+                "input_data_y": TensorConfig(
+                    data_gen=partial(generate_input_y, attrs)
+                ),
+            },
+            outputs=["ele_out", "group_norm_output1"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=["preln_elementwise_groupnorm_act_pass"],
+            max_duration=250,
+            min_success_num=50,
+        )
+
+
if __name__ == "__main__":
    unittest.main()
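Outside the autoscan harness, the fuse pass is exercised through the regular paddle.inference flow. A hedged sketch of such a run (the model paths below are placeholders; any NCHW model containing the elementwise_add -> group_norm (-> silu) chain would do):

import paddle.inference as paddle_infer

# Placeholder paths; substitute a real exported inference model.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)  # 256 MB initial GPU memory pool on device 0
config.enable_tensorrt_engine(
    workspace_size=1 << 28,
    max_batch_size=4,
    min_subgraph_size=0,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
# Fuse passes run while the predictor is built; after this change both the
# silu and no-silu chains should be rewritten into preln_groupnorm_act.
predictor = paddle_infer.create_predictor(config)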