未验证 提交 d3f8ede0 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle Inference] remove conv_act_set from graph_pattern_detector.cc (#48569)

* remove conv_act_set from graph_pattern_detector.cc
上级 2bdad6cd
......@@ -131,8 +131,15 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* x = gpd.mutable_pattern()->NewNode("x")->AsInput()->assert_is_op_input(
"conv2d", "Input");
#if CUDNN_VERSION >= 8000
std::unordered_set<std::string> cudnn_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
pattern(x, cudnn_act_set);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......
......@@ -130,8 +130,15 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
->assert_is_op_input("conv2d", "Input")
->AsInput();
#if CUDNN_VERSION >= 8000
std::unordered_set<std::string> cudnn_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
pattern(x, cudnn_act_set);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......
......@@ -2372,14 +2372,8 @@ PDNode *patterns::PriorBox::operator()() {
return boxes_var;
}
#if CUDNN_VERSION >= 8000
std::unordered_set<std::string> conv_act_set(
{"identity", "relu", "sigmoid", "tanh"});
#else
std::unordered_set<std::string> conv_act_set({"identity", "relu"});
#endif
PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
PDNode *patterns::ConvElementwiseaddAct::operator()(
PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
conv_in->AsInput();
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_out = pattern->NewNode(conv_out_repr())
......@@ -2576,7 +2570,8 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) {
return reshape2_out;
}
PDNode *patterns::ConvElementwiseadd2Act::operator()(PDNode *conv_in) {
PDNode *patterns::ConvElementwiseadd2Act::operator()(
PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto conv_filter = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter")
......
......@@ -1476,7 +1476,8 @@ struct ConvElementwiseaddAct : public PatternBase {
ConvElementwiseaddAct(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PDNode* operator()(PDNode* conv_in,
const std::unordered_set<std::string>& conv_act_set);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_out);
......@@ -1496,7 +1497,8 @@ struct ConvElementwiseadd2Act : public PatternBase {
: PatternBase(
pattern, name_scope, "conv_elementwiseadd2_elementwiseadd_act") {}
PDNode* operator()(PDNode* conv_in);
PDNode* operator()(PDNode* conv_in,
const std::unordered_set<std::string>& conv_act_set);
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_filter);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册