未验证 提交 d3f8ede0 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle Inference] remove conv_act_set from graph_pattern_detector.cc (#48569)

* remove conv_act_set from graph_pattern_detector.cc
上级 2bdad6cd
...@@ -131,8 +131,15 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -131,8 +131,15 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input( auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input"); "conv2d", "Input");
#if CUDNN_VERSION >= 8000
std::unordered_set<std::string> cudnn_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name); patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x); pattern(x, cudnn_act_set);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
......
...@@ -130,8 +130,15 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -130,8 +130,15 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
->assert_is_op_input("conv2d", "Input") ->assert_is_op_input("conv2d", "Input")
->AsInput(); ->AsInput();
#if CUDNN_VERSION >= 8000
std::unordered_set<std::string> cudnn_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name); patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x); pattern(x, cudnn_act_set);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
......
...@@ -2372,14 +2372,8 @@ PDNode *patterns::PriorBox::operator()() { ...@@ -2372,14 +2372,8 @@ PDNode *patterns::PriorBox::operator()() {
return boxes_var; return boxes_var;
} }
#if CUDNN_VERSION >= 8000 PDNode *patterns::ConvElementwiseaddAct::operator()(
std::unordered_set<std::string> conv_act_set( PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
{"identity", "relu", "sigmoid", "tanh"});
#else
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
#endif
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput(); conv_in->AsInput();
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_out = pattern->NewNode(conv_out_repr()) auto conv_out = pattern->NewNode(conv_out_repr())
...@@ -2576,7 +2570,8 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) { ...@@ -2576,7 +2570,8 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) {
return reshape2_out; return reshape2_out;
} }
PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) { PDNode *patterns::ConvElementwiseadd2Act::operator()(
PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_filter = pattern->NewNode(conv_filter_repr()) auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter") ->assert_is_op_input("conv2d", "Filter")
......
...@@ -1476,7 +1476,8 @@ struct ConvElementwiseaddAct : public PatternBase { ...@@ -1476,7 +1476,8 @@ struct ConvElementwiseaddAct : public PatternBase {
ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope) ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {} : PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in); PDNode* operator()(PDNode* conv_in,
const std::unordered_set<std::string>& conv_act_set);
PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out); PATTERN_DECL_NODE(conv_out);
...@@ -1496,7 +1497,8 @@ struct ConvElementwiseadd2Act : public PatternBase { ...@@ -1496,7 +1497,8 @@ struct ConvElementwiseadd2Act : public PatternBase {
: PatternBase( : PatternBase(
pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {} pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in); PDNode* operator()(PDNode* conv_in,
const std::unordered_set<std::string>& conv_act_set);
PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_filter); PATTERN_DECL_NODE(conv_filter);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册