From d3f8ede01820e3e4b1a763df9d460cea0d56b142 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Thu, 1 Dec 2022 17:48:53 +0800 Subject: [PATCH] [Paddle Inference] remove conv_act_set from graph_pattern_detector.cc (#48569) * remove conv_act_set from graph_pattern_detector.cc --- .../ir/conv_elementwise_add2_act_fuse_pass.cc | 9 ++++++++- .../ir/conv_elementwise_add_act_fuse_pass.cc | 9 ++++++++- paddle/fluid/framework/ir/graph_pattern_detector.cc | 13 ++++--------- paddle/fluid/framework/ir/graph_pattern_detector.h | 6 ++++-- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index 6d9611ebd13..737fa23f737 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -131,8 +131,15 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input( "conv2d", "Input"); +#if CUDNN_VERSION >= 8000 + std::unordered_set cudnn_act_set( + {"identity", "relu", "sigmoid", "tanh"}); +#else + std::unordered_set cudnn_act_set({"identity", "relu"}); +#endif + patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name); - pattern(x); + pattern(x, cudnn_act_set); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index 47e2c5e380b..1d309d13379 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -130,8 +130,15 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ->assert_is_op_input("conv2d", "Input") ->AsInput(); +#if CUDNN_VERSION >= 8000 + std::unordered_set cudnn_act_set( + {"identity", "relu", "sigmoid", "tanh"}); +#else + std::unordered_set cudnn_act_set({"identity", "relu"}); +#endif + patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name); - pattern(x); + pattern(x, cudnn_act_set); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index acbaef67a68..dd5edaaa9c8 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2372,14 +2372,8 @@ PDNode *patterns::PriorBox::operator()() { return boxes_var; } -#if CUDNN_VERSION >= 8000 -std::unordered_set conv_act_set( - {"identity", "relu", "sigmoid", "tanh"}); -#else -std::unordered_set conv_act_set({"identity", "relu"}); -#endif - -PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { +PDNode *patterns::ConvElementwiseaddAct::operator()( + PDNode *conv_in, const std::unordered_set &conv_act_set) { conv_in->AsInput(); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_out = pattern->NewNode(conv_out_repr()) @@ -2576,7 +2570,8 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) { return reshape2_out; } -PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) { +PDNode *patterns::ConvElementwiseadd2Act::operator()( + PDNode *conv_in, const std::unordered_set &conv_act_set) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_filter = pattern->NewNode(conv_filter_repr()) ->assert_is_op_input("conv2d", "Filter") diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index da479c1bf7c..f8f985fa599 100755 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1476,7 +1476,8 @@ struct ConvElementwiseaddAct : public PatternBase { ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {} - PDNode* operator()(PDNode* conv_in); + PDNode* operator()(PDNode* conv_in, + const std::unordered_set& conv_act_set); PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_out); @@ -1496,7 +1497,8 @@ struct ConvElementwiseadd2Act : public PatternBase { : PatternBase( pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {} - PDNode* operator()(PDNode* conv_in); + PDNode* operator()(PDNode* conv_in, + const std::unordered_set& conv_act_set); PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_filter); -- GitLab