diff --git a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc index 6d9611ebd13931a58215f4638c0886881bd38c51..737fa23f73732d76ce66021980ecd0d5e85d068d 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc @@ -131,8 +131,15 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input( "conv2d", "Input"); +#if CUDNN_VERSION >= 8000 + std::unordered_set cudnn_act_set( + {"identity", "relu", "sigmoid", "tanh"}); +#else + std::unordered_set cudnn_act_set({"identity", "relu"}); +#endif + patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name); - pattern(x); + pattern(x, cudnn_act_set); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index 47e2c5e380bcbfb6410e81cba9b2f7896f4018ae..1d309d133795c5d7f7ccceb3e0177b41c37b1246 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -130,8 +130,15 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ->assert_is_op_input("conv2d", "Input") ->AsInput(); +#if CUDNN_VERSION >= 8000 + std::unordered_set cudnn_act_set( + {"identity", "relu", "sigmoid", "tanh"}); +#else + std::unordered_set cudnn_act_set({"identity", "relu"}); +#endif + patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name); - pattern(x); + pattern(x, cudnn_act_set); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index acbaef67a68fc47fe567b5ec1d950e32235292e0..dd5edaaa9c821f35f54877ad948aaaa28bbfb996 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2372,14 +2372,8 @@ PDNode *patterns::PriorBox::operator()() { return boxes_var; } -#if CUDNN_VERSION >= 8000 -std::unordered_set conv_act_set( - {"identity", "relu", "sigmoid", "tanh"}); -#else -std::unordered_set conv_act_set({"identity", "relu"}); -#endif - -PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) { +PDNode *patterns::ConvElementwiseaddAct::operator()( + PDNode *conv_in, const std::unordered_set &conv_act_set) { conv_in->AsInput(); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_out = pattern->NewNode(conv_out_repr()) @@ -2576,7 +2570,8 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) { return reshape2_out; } -PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) { +PDNode *patterns::ConvElementwiseadd2Act::operator()( + PDNode *conv_in, const std::unordered_set &conv_act_set) { auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_filter = pattern->NewNode(conv_filter_repr()) ->assert_is_op_input("conv2d", "Filter") diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index da479c1bf7c9b34195ea4785ec0723895c1269b7..f8f985fa5994ec015dee8cc2a01a8331b1461135 100755 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1476,7 +1476,8 @@ struct ConvElementwiseaddAct : public PatternBase { ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {} - PDNode* operator()(PDNode* conv_in); + PDNode* operator()(PDNode* conv_in, + const std::unordered_set& conv_act_set); PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_out); @@ -1496,7 +1497,8 @@ struct ConvElementwiseadd2Act : public PatternBase { : PatternBase( pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {} - PDNode* operator()(PDNode* conv_in); + PDNode* operator()(PDNode* conv_in, + const std::unordered_set& conv_act_set); PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_filter);