diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index 08c3b23cf30d072035997b3027df15a9bea28e9c..fd47b96c101e94ce9d510fa8720151717b65c60b 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -109,7 +109,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { } SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"}); - SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); + SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"}); SetOp(&prog, "relu", {"d"}, {"e"}); return prog; @@ -160,7 +160,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { } SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"}); - SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); + SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"}); return prog; }; @@ -211,7 +211,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { SetOp(&prog, "sigmoid", {"a"}, {"b"}); SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"}); - SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"}); + SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"}); SetOp(&prog, "relu", {"e"}, {"f"}); return prog; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6d524651e0b34dd917aec1f5a099dd5849a4a409..786765bff7214e6f42100c2ab508a9b018634ba2 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1024,15 +1024,15 @@ PDNode *patterns::Conv::operator()() { return output_var; } -PDNode *patterns::ElementwiseAdd::operator()(PDNode *y_var) { +PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) { auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr()) ->assert_is_op("elementwise_add"); - auto x_var = pattern->NewNode(elementwise_add_x_repr()) - ->AsInput() - ->assert_is_op_input("elementwise_add", "X"); + x_var->assert_is_op_input("elementwise_add", "X"); - y_var->assert_is_op_input("elementwise_add", "Y"); + auto y_var = pattern->NewNode(elementwise_add_x_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add", "Y"); auto out_var = pattern->NewNode(elementwise_add_out_repr()) ->AsOutput() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 08fd8174cebb9be15c603a08e4b2ac6298b31a9a..8e4f4a14ab78802c6dd275f698c278d05faae350 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -631,7 +631,7 @@ struct ElementwiseAdd : public PatternBase { ElementwiseAdd(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "elementwise_add") {} - PDNode* operator()(PDNode* y_var); + PDNode* operator()(PDNode* x_var); PATTERN_DECL_NODE(elementwise_add_op); PATTERN_DECL_NODE(elementwise_add_x);