提交 42f569fd 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: use_mkldnn attribute added

上级 441d3a47
......@@ -39,25 +39,25 @@ struct Conv {
std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) {
return [&]() -> PDNode* {
auto conv_op = pattern->new_node(op_name())
->assert_is_op("conv2d");
auto conv_op = pattern->new_node(op_name())
->assert_is_op("conv2d");
auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(),
input_name());
auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(),
filter_name());
auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(),
input_name());
auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(),
filter_name());
auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(),
output_name());
auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(),
output_name());
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
return output_var;
};
}
};
......@@ -139,7 +139,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate();
auto fuse_conv = [&](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
OpDesc op_desc;
op_desc.SetType("conv2d");
......@@ -147,7 +147,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetOutput("Output", {y->Name()});
op_desc.SetAttr("fuse_sum", true);
op_desc.SetAttr("use_mkldnn", true);
op_desc.SetAttr("fuse_eltwise", true);
auto fused_conv_op = g->CreateOpNode(&op_desc);
......@@ -175,7 +176,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
patterns::CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
patterns::GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
};
gpd(graph.get(), handler);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册