提交 0fe3079c 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: fix for order of parameters in elementwise_add in resnet50

test=develop
上级 b73b8683
...@@ -109,7 +109,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) { ...@@ -109,7 +109,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
} }
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"}); SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"}); SetOp(&prog, "relu", {"d"}, {"e"});
return prog; return prog;
...@@ -160,7 +160,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) { ...@@ -160,7 +160,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
} }
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"}); SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"}); SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
return prog; return prog;
}; };
...@@ -211,7 +211,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) { ...@@ -211,7 +211,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
SetOp(&prog, "sigmoid", {"a"}, {"b"}); SetOp(&prog, "sigmoid", {"a"}, {"b"});
SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"}); SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"}); SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
SetOp(&prog, "relu", {"e"}, {"f"}); SetOp(&prog, "relu", {"e"}, {"f"});
return prog; return prog;
......
...@@ -1024,15 +1024,15 @@ PDNode *patterns::Conv::operator()() { ...@@ -1024,15 +1024,15 @@ PDNode *patterns::Conv::operator()() {
return output_var; return output_var;
} }
PDNode *patterns::ElementwiseAdd::operator()(PDNode *y_var) { PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr()) auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add"); ->assert_is_op("elementwise_add");
auto x_var = pattern->NewNode(elementwise_add_x_repr()) x_var->assert_is_op_input("elementwise_add", "X");
->AsInput()
->assert_is_op_input("elementwise_add", "X");
y_var->assert_is_op_input("elementwise_add", "Y"); auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr()) auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput() ->AsOutput()
......
...@@ -631,7 +631,7 @@ struct ElementwiseAdd : public PatternBase { ...@@ -631,7 +631,7 @@ struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope) ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {} : PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* y_var); PDNode* operator()(PDNode* x_var);
PATTERN_DECL_NODE(elementwise_add_op); PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x); PATTERN_DECL_NODE(elementwise_add_x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册