提交 cbe122ae 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: correcting formatting

上级 2a251bbf
...@@ -58,13 +58,13 @@ struct Conv { ...@@ -58,13 +58,13 @@ struct Conv {
auto conv_op = pattern->new_node(op_name())->assert_is_op(op_name()); auto conv_op = pattern->new_node(op_name())->assert_is_op(op_name());
auto input_var = pattern->new_node(input_name()) auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(), input_name()); ->assert_is_op_input(op_name(), input_name());
auto filter_var = pattern->new_node(filter_name()) auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(), filter_name()); ->assert_is_op_input(op_name(), filter_name());
auto output_var = pattern->new_node(output_name()) auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(), output_name()); ->assert_is_op_output(op_name(), output_name());
conv_op->LinksFrom({input_var, filter_var}); conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var}); conv_op->LinksTo({output_var});
...@@ -91,8 +91,8 @@ struct ElementwiseAdd { ...@@ -91,8 +91,8 @@ struct ElementwiseAdd {
conv_output->assert_is_op_input(op_name(), y_name()); conv_output->assert_is_op_input(op_name(), y_name());
auto out_var = pattern->new_node(out_name()) auto out_var = pattern->new_node(out_name())
->AsOutput() ->AsOutput()
->assert_is_op_output(op_name(), out_name()); ->assert_is_op_output(op_name(), out_name());
elementwise_add_op->LinksFrom({x_var, conv_output}); elementwise_add_op->LinksFrom({x_var, conv_output});
elementwise_add_op->LinksTo({out_var}); elementwise_add_op->LinksTo({out_var});
...@@ -179,15 +179,15 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -179,15 +179,15 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input, auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input,
Node* conv_filter, Node* conv_filter, Node* conv_output,
Node* conv_output,
Node* elementwise_add_x) { Node* elementwise_add_x) {
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType(conv_pattern.op_name()); op_desc.SetType(conv_pattern.op_name());
op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()}); op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()});
op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()}); op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()});
op_desc.SetInput(conv_pattern.residual_data_name(), {elementwise_add_x->Name()}); op_desc.SetInput(conv_pattern.residual_data_name(),
{elementwise_add_x->Name()});
op_desc.SetOutput(conv_pattern.output_name(), {conv_output->Name()}); op_desc.SetOutput(conv_pattern.output_name(), {conv_output->Name()});
op_desc.SetAttr("use_mkldnn", true); op_desc.SetAttr("use_mkldnn", true);
...@@ -201,8 +201,9 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -201,8 +201,9 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
patterns::LinkNodes(fused_conv_op, conv_output); patterns::LinkNodes(fused_conv_op, conv_output);
}; };
auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr, fuse_conv] auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr,
(const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { fuse_conv](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.op_name()); conv_pattern.op_name());
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr, auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册