提交 0fe3079c 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: fix for order of parameters in elementwise_add in resnet50

test=develop
上级 b73b8683
......@@ -109,7 +109,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
}
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"});
return prog;
......@@ -160,7 +160,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
}
SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
return prog;
};
......@@ -211,7 +211,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
SetOp(&prog, "sigmoid", {"a"}, {"b"});
SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
SetOp(&prog, "relu", {"e"}, {"f"});
return prog;
......
......@@ -1024,15 +1024,15 @@ PDNode *patterns::Conv::operator()() {
return output_var;
}
PDNode *patterns::ElementwiseAdd::operator()(PDNode *y_var) {
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto x_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "X");
x_var->assert_is_op_input("elementwise_add", "X");
y_var->assert_is_op_input("elementwise_add", "Y");
auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
......
......@@ -631,7 +631,7 @@ struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* y_var);
PDNode* operator()(PDNode* x_var);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册