提交 42f569fd 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: use_mkldnn attribute added

上级 441d3a47
......@@ -139,7 +139,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate();
auto fuse_conv = [&](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
OpDesc op_desc;
op_desc.SetType("conv2d");
......@@ -147,7 +147,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetOutput("Output", {y->Name()});
op_desc.SetAttr("fuse_sum", true);
op_desc.SetAttr("use_mkldnn", true);
op_desc.SetAttr("fuse_eltwise", true);
auto fused_conv_op = g->CreateOpNode(&op_desc);
......@@ -175,7 +176,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
patterns::GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
};
gpd(graph.get(), handler);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册