提交 8fb29b2c 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: new nodes marked as input or output

test=develop
上级 cc1c8e37
...@@ -1003,15 +1003,19 @@ PDNode *patterns::Conv::operator()() { ...@@ -1003,15 +1003,19 @@ PDNode *patterns::Conv::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto input_var = pattern->NewNode(conv_input_repr()) auto input_var = pattern->NewNode(conv_input_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input"); ->assert_is_op_input("conv2d", "Input");
auto bias_var = auto bias_var = pattern->NewNode(conv_bias_repr())
pattern->NewNode(conv_bias_repr())->assert_is_op_input("conv2d", "Bias"); ->AsInput()
->assert_is_op_input("conv2d", "Bias");
auto filter_var = pattern->NewNode(conv_filter_repr()) auto filter_var = pattern->NewNode(conv_filter_repr())
->AsInput()
->assert_is_op_input("conv2d", "Filter"); ->assert_is_op_input("conv2d", "Filter");
auto output_var = pattern->NewNode(conv_output_repr()) auto output_var = pattern->NewNode(conv_output_repr())
->AsOutput()
->assert_is_op_output("conv2d", "Output"); ->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, bias_var, filter_var}); conv_op->LinksFrom({input_var, bias_var, filter_var});
...@@ -1025,6 +1029,7 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *y_var) { ...@@ -1025,6 +1029,7 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *y_var) {
->assert_is_op("elementwise_add"); ->assert_is_op("elementwise_add");
auto x_var = pattern->NewNode(elementwise_add_x_repr()) auto x_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "X"); ->assert_is_op_input("elementwise_add", "X");
y_var->assert_is_op_input("elementwise_add", "Y"); y_var->assert_is_op_input("elementwise_add", "Y");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册